Handle SAME_LOWER and SAME_UPPER.
This commit is contained in:
parent
ec9e023f04
commit
169236a8fc
|
@ -412,7 +412,7 @@ void ONNXReshapeOp::inferShapes() {
|
|||
void ONNXTransposeOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!getOperand().getType().isa<RankedTensorType>())
|
||||
emitError("Shape tensor not ranked.");
|
||||
return;
|
||||
|
||||
// Naive transposition which handles the default case of
|
||||
// reversing the shape of the tensor (similar to numpy.transpose).
|
||||
|
@ -464,6 +464,7 @@ void ONNXConvNoBiasOp::inferShapes() {
|
|||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
return;
|
||||
|
||||
auto dataTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto weightTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||
auto dataShape = dataTy.getShape();
|
||||
|
@ -492,34 +493,37 @@ void ONNXConvNoBiasOp::inferShapes() {
|
|||
// Insert number of filters being applied (number of output channels).
|
||||
dims.emplace_back(weightShape[0]);
|
||||
|
||||
// Spatial dimensions are computed using the formula:
|
||||
// Spatial dimensions of the output are computed using the formula:
|
||||
//
|
||||
// dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1
|
||||
//
|
||||
SmallVector<int64_t, 2> spatialDims;
|
||||
SmallVector<int64_t, 2> outSpatialDims;
|
||||
// Number of spatial dimensions.
|
||||
int32_t nDims = dataShape.size() - 2;
|
||||
|
||||
// Initialize dimenions based on the input spatial dimensions.
|
||||
for (int i = 2; i < dataShape.size(); ++i)
|
||||
spatialDims.emplace_back(dataShape[i]);
|
||||
outSpatialDims.emplace_back(dataShape[i]);
|
||||
|
||||
// Use kernel_shape attribute if present otherwise use size from weight
|
||||
// argument.
|
||||
if (auto kernel_shape = getAttrOfType<ArrayAttr>(
|
||||
SmallVector<int64_t, 2> kernelDims;
|
||||
if (auto kernelShape = getAttrOfType<ArrayAttr>(
|
||||
ONNXConvOp::getKernelShapeAttrName())) {
|
||||
if (kernel_shape.getValue().size() != nDims)
|
||||
if (kernelShape.getValue().size() != nDims)
|
||||
emitError("kernel_shape length incompatible with spatial dimensions.");
|
||||
for (int i = 0; i < nDims; ++i) {
|
||||
int64_t kernelDim =
|
||||
(kernel_shape.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||
spatialDims[i] -= kernelDim;
|
||||
}
|
||||
for (int i = 0; i < nDims; ++i)
|
||||
kernelDims[i] =
|
||||
(kernelShape.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||
} else {
|
||||
for (int i = 0; i < nDims; ++i)
|
||||
spatialDims[i] -= weightShape[i + 2];
|
||||
kernelDims[i] = weightShape[i + 2];
|
||||
}
|
||||
|
||||
// Subtract kernel dimensions from input data dimensions.
|
||||
for (int i = 0; i < nDims; ++i)
|
||||
outSpatialDims[i] -= kernelDims[i];
|
||||
|
||||
// Add padding information.
|
||||
if (autoPad == "NOTSET") {
|
||||
// Use pads to to determine the padding. If attribute is not
|
||||
|
@ -533,18 +537,23 @@ void ONNXConvNoBiasOp::inferShapes() {
|
|||
for (int i = 0; i < nDims; ++i) {
|
||||
// Padding for beginning of axis.
|
||||
int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||
spatialDims[i] += p;
|
||||
outSpatialDims[i] += p;
|
||||
// Padding for end of axis.
|
||||
p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
|
||||
spatialDims[i] += p;
|
||||
outSpatialDims[i] += p;
|
||||
}
|
||||
}
|
||||
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
|
||||
// Pad input so that output size matches input size.
|
||||
// Each spatial dimension needs to be padded by:
|
||||
//
|
||||
// ( K - 1 ) / 2
|
||||
//
|
||||
// where K is a kernel spatial dimension.
|
||||
for (int i = 0; i < nDims; ++i)
|
||||
outSpatialDims[i] += floor((kernelDims[i] - 1) / 2);
|
||||
} else if (autoPad == "VALID") {
|
||||
// TODO
|
||||
} else if (autoPad == "SAME_UPPER") {
|
||||
// TODO
|
||||
} else if (autoPad == "SAME_LOWER") {
|
||||
// TODO
|
||||
// No padding
|
||||
} else {
|
||||
emitError("Unexpected attribute value for auto_pad.");
|
||||
}
|
||||
|
@ -557,14 +566,14 @@ void ONNXConvNoBiasOp::inferShapes() {
|
|||
for (int i = 0; i < nDims; ++i) {
|
||||
int64_t stride =
|
||||
(strides.getValue()[i]).cast<IntegerAttr>().getInt();
|
||||
spatialDims[i] = floor(spatialDims[i] / stride);
|
||||
outSpatialDims[i] = floor(outSpatialDims[i] / stride);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < nDims; ++i)
|
||||
spatialDims[i] += 1;
|
||||
outSpatialDims[i] += 1;
|
||||
|
||||
dims.append(spatialDims.begin(), spatialDims.end());
|
||||
dims.append(outSpatialDims.begin(), outSpatialDims.end());
|
||||
getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue