Handle SAME_LOWER and SAME_UPPER.

This commit is contained in:
Doru Bercea 2020-01-21 20:39:11 -05:00
parent ec9e023f04
commit 169236a8fc
1 changed files with 31 additions and 22 deletions

View File

@ -412,7 +412,7 @@ void ONNXReshapeOp::inferShapes() {
void ONNXTransposeOp::inferShapes() { void ONNXTransposeOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!getOperand().getType().isa<RankedTensorType>()) if (!getOperand().getType().isa<RankedTensorType>())
emitError("Shape tensor not ranked."); return;
// Naive transposition which handles the default case of // Naive transposition which handles the default case of
// reversing the shape of the tensor (similar to numpy.transpose). // reversing the shape of the tensor (similar to numpy.transpose).
@ -464,6 +464,7 @@ void ONNXConvNoBiasOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return;
auto dataTy = getOperand(0).getType().cast<RankedTensorType>(); auto dataTy = getOperand(0).getType().cast<RankedTensorType>();
auto weightTy = getOperand(1).getType().cast<RankedTensorType>(); auto weightTy = getOperand(1).getType().cast<RankedTensorType>();
auto dataShape = dataTy.getShape(); auto dataShape = dataTy.getShape();
@ -492,34 +493,37 @@ void ONNXConvNoBiasOp::inferShapes() {
// Insert number of filters being applied (number of output channels). // Insert number of filters being applied (number of output channels).
dims.emplace_back(weightShape[0]); dims.emplace_back(weightShape[0]);
// Spatial dimensions are computed using the formula: // Spatial dimensions of the output are computed using the formula:
// //
// dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1 // dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1
// //
SmallVector<int64_t, 2> spatialDims; SmallVector<int64_t, 2> outSpatialDims;
// Number of spatial dimensions. // Number of spatial dimensions.
int32_t nDims = dataShape.size() - 2; int32_t nDims = dataShape.size() - 2;
// Initialize dimenions based on the input spatial dimensions. // Initialize dimenions based on the input spatial dimensions.
for (int i = 2; i < dataShape.size(); ++i) for (int i = 2; i < dataShape.size(); ++i)
spatialDims.emplace_back(dataShape[i]); outSpatialDims.emplace_back(dataShape[i]);
// Use kernel_shape attribute if present otherwise use size from weight // Use kernel_shape attribute if present otherwise use size from weight
// argument. // argument.
if (auto kernel_shape = getAttrOfType<ArrayAttr>( SmallVector<int64_t, 2> kernelDims;
if (auto kernelShape = getAttrOfType<ArrayAttr>(
ONNXConvOp::getKernelShapeAttrName())) { ONNXConvOp::getKernelShapeAttrName())) {
if (kernel_shape.getValue().size() != nDims) if (kernelShape.getValue().size() != nDims)
emitError("kernel_shape length incompatible with spatial dimensions."); emitError("kernel_shape length incompatible with spatial dimensions.");
for (int i = 0; i < nDims; ++i) { for (int i = 0; i < nDims; ++i)
int64_t kernelDim = kernelDims[i] =
(kernel_shape.getValue()[i]).cast<IntegerAttr>().getInt(); (kernelShape.getValue()[i]).cast<IntegerAttr>().getInt();
spatialDims[i] -= kernelDim;
}
} else { } else {
for (int i = 0; i < nDims; ++i) for (int i = 0; i < nDims; ++i)
spatialDims[i] -= weightShape[i + 2]; kernelDims[i] = weightShape[i + 2];
} }
// Subtract kernel dimensions from input data dimensions.
for (int i = 0; i < nDims; ++i)
outSpatialDims[i] -= kernelDims[i];
// Add padding information. // Add padding information.
if (autoPad == "NOTSET") { if (autoPad == "NOTSET") {
// Use pads to to determine the padding. If attribute is not // Use pads to to determine the padding. If attribute is not
@ -533,18 +537,23 @@ void ONNXConvNoBiasOp::inferShapes() {
for (int i = 0; i < nDims; ++i) { for (int i = 0; i < nDims; ++i) {
// Padding for beginning of axis. // Padding for beginning of axis.
int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt(); int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
spatialDims[i] += p; outSpatialDims[i] += p;
// Padding for end of axis. // Padding for end of axis.
p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt(); p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
spatialDims[i] += p; outSpatialDims[i] += p;
} }
} }
} else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
// Pad input so that output size matches input size.
// Each spatial dimension needs to be padded by:
//
// ( K - 1 ) / 2
//
// where K is a kernel spatial dimension.
for (int i = 0; i < nDims; ++i)
outSpatialDims[i] += floor((kernelDims[i] - 1) / 2);
} else if (autoPad == "VALID") { } else if (autoPad == "VALID") {
// TODO // No padding
} else if (autoPad == "SAME_UPPER") {
// TODO
} else if (autoPad == "SAME_LOWER") {
// TODO
} else { } else {
emitError("Unexpected attribute value for auto_pad."); emitError("Unexpected attribute value for auto_pad.");
} }
@ -557,14 +566,14 @@ void ONNXConvNoBiasOp::inferShapes() {
for (int i = 0; i < nDims; ++i) { for (int i = 0; i < nDims; ++i) {
int64_t stride = int64_t stride =
(strides.getValue()[i]).cast<IntegerAttr>().getInt(); (strides.getValue()[i]).cast<IntegerAttr>().getInt();
spatialDims[i] = floor(spatialDims[i] / stride); outSpatialDims[i] = floor(outSpatialDims[i] / stride);
} }
} }
for (int i = 0; i < nDims; ++i) for (int i = 0; i < nDims; ++i)
spatialDims[i] += 1; outSpatialDims[i] += 1;
dims.append(spatialDims.begin(), spatialDims.end()); dims.append(outSpatialDims.begin(), outSpatialDims.end());
getResult().setType(RankedTensorType::get(dims, dataTy.getElementType())); getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
} }