Handle SAME_LOWER and SAME_UPPER.

Doru Bercea 2020-01-21 20:39:11 -05:00
parent ec9e023f04
commit 169236a8fc
1 changed file with 31 additions and 22 deletions


@@ -412,7 +412,7 @@ void ONNXReshapeOp::inferShapes() {
 void ONNXTransposeOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!getOperand().getType().isa<RankedTensorType>())
-    emitError("Shape tensor not ranked.");
+    return;
 
   // Naive transposition which handles the default case of
   // reversing the shape of the tensor (similar to numpy.transpose).
@@ -464,6 +464,7 @@ void ONNXConvNoBiasOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
       !getOperand(1).getType().isa<RankedTensorType>())
     return;
+
   auto dataTy = getOperand(0).getType().cast<RankedTensorType>();
   auto weightTy = getOperand(1).getType().cast<RankedTensorType>();
   auto dataShape = dataTy.getShape();
@@ -492,34 +493,37 @@ void ONNXConvNoBiasOp::inferShapes() {
   // Insert number of filters being applied (number of output channels).
   dims.emplace_back(weightShape[0]);
-  // Spatial dimensions are computed using the formula:
+  // Spatial dimensions of the output are computed using the formula:
   //
   //   dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1
   //
-  SmallVector<int64_t, 2> spatialDims;
+  SmallVector<int64_t, 2> outSpatialDims;
   // Number of spatial dimensions.
   int32_t nDims = dataShape.size() - 2;
   // Initialize dimenions based on the input spatial dimensions.
   for (int i = 2; i < dataShape.size(); ++i)
-    spatialDims.emplace_back(dataShape[i]);
+    outSpatialDims.emplace_back(dataShape[i]);
   // Use kernel_shape attribute if present otherwise use size from weight
   // argument.
-  if (auto kernel_shape = getAttrOfType<ArrayAttr>(
+  SmallVector<int64_t, 2> kernelDims;
+  if (auto kernelShape = getAttrOfType<ArrayAttr>(
           ONNXConvOp::getKernelShapeAttrName())) {
-    if (kernel_shape.getValue().size() != nDims)
+    if (kernelShape.getValue().size() != nDims)
      emitError("kernel_shape length incompatible with spatial dimensions.");
-    for (int i = 0; i < nDims; ++i) {
-      int64_t kernelDim =
-          (kernel_shape.getValue()[i]).cast<IntegerAttr>().getInt();
-      spatialDims[i] -= kernelDim;
-    }
+    for (int i = 0; i < nDims; ++i)
+      kernelDims[i] =
+          (kernelShape.getValue()[i]).cast<IntegerAttr>().getInt();
   } else {
     for (int i = 0; i < nDims; ++i)
-      spatialDims[i] -= weightShape[i + 2];
+      kernelDims[i] = weightShape[i + 2];
   }
+  // Subtract kernel dimensions from input data dimensions.
+  for (int i = 0; i < nDims; ++i)
+    outSpatialDims[i] -= kernelDims[i];
   // Add padding information.
   if (autoPad == "NOTSET") {
     // Use pads to to determine the padding. If attribute is not
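
The formula quoted in the comment above is easy to check in isolation. Below is a minimal standalone sketch of the per-dimension arithmetic that this hunk spreads across several loops; the helper name convOutputDim and its explicit parameters are illustrative only and are not part of the commit.

#include <cstdint>

// Output size of one convolution spatial dimension (no dilation), mirroring
//   dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1
int64_t convOutputDim(int64_t inputDim, int64_t kernelDim,
                      int64_t startPadding, int64_t endPadding,
                      int64_t stride) {
  return (inputDim - kernelDim + startPadding + endPadding) / stride + 1;
}

// Example: a 7x7 input with a 3x3 kernel, no padding, stride 1 gives
// convOutputDim(7, 3, 0, 0, 1) == 5, i.e. a 5x5 output.
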
@@ -533,18 +537,23 @@ void ONNXConvNoBiasOp::inferShapes() {
       for (int i = 0; i < nDims; ++i) {
         // Padding for beginning of axis.
         int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
-        spatialDims[i] += p;
+        outSpatialDims[i] += p;
         // Padding for end of axis.
         p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
-        spatialDims[i] += p;
+        outSpatialDims[i] += p;
       }
     }
+  } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
+    // Pad input so that output size matches input size.
+    // Each spatial dimension needs to be padded by:
+    //
+    //    ( K - 1 ) / 2
+    //
+    // where K is a kernel spatial dimension.
+    for (int i = 0; i < nDims; ++i)
+      outSpatialDims[i] += floor((kernelDims[i] - 1) / 2);
   } else if (autoPad == "VALID") {
-    // TODO
-  } else if (autoPad == "SAME_UPPER") {
-    // TODO
-  } else if (autoPad == "SAME_LOWER") {
-    // TODO
+    // No padding
   } else {
     emitError("Unexpected attribute value for auto_pad.");
   }
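
For shape inference only the total padding per axis matters, which is why SAME_UPPER and SAME_LOWER can share a single branch in the hunk above; per the ONNX auto_pad definition the two modes differ only in where the extra pixel of an odd total goes (at the end for SAME_UPPER, at the beginning for SAME_LOWER). A small sketch of that split, with an illustrative helper name that is not part of the commit:

#include <cstdint>
#include <utility>

// Split a total padding amount between the start and end of one axis.
// SAME_UPPER places the extra pixel of an odd total at the end,
// SAME_LOWER places it at the beginning.
std::pair<int64_t, int64_t> splitSamePadding(int64_t totalPad, bool sameUpper) {
  int64_t small = totalPad / 2;
  int64_t large = totalPad - small;
  return sameUpper ? std::make_pair(small, large)
                   : std::make_pair(large, small);
}

// Example: totalPad = 3 gives (1, 2) for SAME_UPPER and (2, 1) for SAME_LOWER.
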
@@ -557,14 +566,14 @@ void ONNXConvNoBiasOp::inferShapes() {
     for (int i = 0; i < nDims; ++i) {
       int64_t stride =
           (strides.getValue()[i]).cast<IntegerAttr>().getInt();
-      spatialDims[i] = floor(spatialDims[i] / stride);
+      outSpatialDims[i] = floor(outSpatialDims[i] / stride);
     }
   }
 
   for (int i = 0; i < nDims; ++i)
-    spatialDims[i] += 1;
+    outSpatialDims[i] += 1;
 
-  dims.append(spatialDims.begin(), spatialDims.end());
+  dims.append(outSpatialDims.begin(), outSpatialDims.end());
   getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
 }
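
Putting the hunks together, here is a sketch of the spatial-shape computation the updated ONNXConvNoBiasOp::inferShapes performs, reduced to plain integers with no MLIR types; the function name and parameters are illustrative only, and the SAME_UPPER/SAME_LOWER branch reproduces the commit's (kernel - 1) / 2 adjustment as written.

#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

std::vector<int64_t> convOutputSpatialDims(
    const std::vector<int64_t> &inSpatial,   // input spatial dims (H, W, ...)
    const std::vector<int64_t> &kernelDims,  // kernel spatial dims
    const std::string &autoPad,              // NOTSET, SAME_UPPER, SAME_LOWER, VALID
    const std::vector<int64_t> &pads,        // begin pads then end pads (NOTSET only)
    const std::vector<int64_t> &strides) {
  const std::size_t nDims = inSpatial.size();
  std::vector<int64_t> out(inSpatial);

  // Start from inputDim - kernelDim.
  for (std::size_t i = 0; i < nDims; ++i)
    out[i] -= kernelDims[i];

  // Add padding.
  if (autoPad == "NOTSET") {
    for (std::size_t i = 0; i < nDims; ++i)
      out[i] += pads[i] + pads[i + nDims];   // start + end padding
  } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
    for (std::size_t i = 0; i < nDims; ++i)
      out[i] += (kernelDims[i] - 1) / 2;     // as in the commit's comment
  }                                          // VALID: no padding

  // Divide by stride and add one.
  for (std::size_t i = 0; i < nDims; ++i)
    out[i] = out[i] / strides[i] + 1;
  return out;
}
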