diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 2b65c15..e597c9a 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -412,7 +412,7 @@ void ONNXReshapeOp::inferShapes() {
 void ONNXTransposeOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!getOperand().getType().isa<RankedTensorType>())
-    emitError("Shape tensor not ranked.");
+    return;
 
   // Naive transposition which handles the default case of
   // reversing the shape of the tensor (similar to numpy.transpose).
@@ -464,6 +464,7 @@ void ONNXConvNoBiasOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
       !getOperand(1).getType().isa<RankedTensorType>())
     return;
+
   auto dataTy = getOperand(0).getType().cast<RankedTensorType>();
   auto weightTy = getOperand(1).getType().cast<RankedTensorType>();
   auto dataShape = dataTy.getShape();
@@ -492,34 +493,37 @@ void ONNXConvNoBiasOp::inferShapes() {
   // Insert number of filters being applied (number of output channels).
   dims.emplace_back(weightShape[0]);
 
-  // Spatial dimensions are computed using the formula:
+  // Spatial dimensions of the output are computed using the formula:
   //
   // dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1
   //
-  SmallVector<int64_t, 2> spatialDims;
+  SmallVector<int64_t, 2> outSpatialDims;
   // Number of spatial dimensions.
   int32_t nDims = dataShape.size() - 2;
 
   // Initialize dimensions based on the input spatial dimensions.
   for (int i = 2; i < dataShape.size(); ++i)
-    spatialDims.emplace_back(dataShape[i]);
+    outSpatialDims.emplace_back(dataShape[i]);
 
   // Use kernel_shape attribute if present, otherwise use size from weight
   // argument.
-  if (auto kernel_shape = getAttrOfType<ArrayAttr>(
+  SmallVector<int64_t, 2> kernelDims;
+  if (auto kernelShape = getAttrOfType<ArrayAttr>(
           ONNXConvOp::getKernelShapeAttrName())) {
-    if (kernel_shape.getValue().size() != nDims)
+    if (kernelShape.getValue().size() != nDims)
       emitError("kernel_shape length incompatible with spatial dimensions.");
-    for (int i = 0; i < nDims; ++i) {
-      int64_t kernelDim =
-          (kernel_shape.getValue()[i]).cast<IntegerAttr>().getInt();
-      spatialDims[i] -= kernelDim;
-    }
+    for (int i = 0; i < nDims; ++i)
+      kernelDims.emplace_back(
+          (kernelShape.getValue()[i]).cast<IntegerAttr>().getInt());
   } else {
     for (int i = 0; i < nDims; ++i)
-      spatialDims[i] -= weightShape[i + 2];
+      kernelDims.emplace_back(weightShape[i + 2]);
   }
 
+  // Subtract kernel dimensions from input data dimensions.
+  for (int i = 0; i < nDims; ++i)
+    outSpatialDims[i] -= kernelDims[i];
+
   // Add padding information.
   if (autoPad == "NOTSET") {
     // Use pads to determine the padding. If attribute is not
@@ -533,18 +537,23 @@ void ONNXConvNoBiasOp::inferShapes() {
       for (int i = 0; i < nDims; ++i) {
         // Padding for beginning of axis.
         int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
-        spatialDims[i] += p;
+        outSpatialDims[i] += p;
         // Padding for end of axis.
         p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
-        spatialDims[i] += p;
+        outSpatialDims[i] += p;
       }
     }
+  } else if (autoPad == "SAME_UPPER" || autoPad == "SAME_LOWER") {
+    // Pad input so that output size matches input size.
+    // Each spatial dimension needs to be padded by a total of:
+    //
+    //    K - 1
+    //
+    // where K is a kernel spatial dimension.
+    for (int i = 0; i < nDims; ++i)
+      outSpatialDims[i] += kernelDims[i] - 1;
   } else if (autoPad == "VALID") {
-    // TODO
-  } else if (autoPad == "SAME_UPPER") {
-    // TODO
-  } else if (autoPad == "SAME_LOWER") {
-    // TODO
+    // No padding.
   } else {
     emitError("Unexpected attribute value for auto_pad.");
   }
@@ -557,14 +566,14 @@ void ONNXConvNoBiasOp::inferShapes() {
     for (int i = 0; i < nDims; ++i) {
       int64_t stride =
           (strides.getValue()[i]).cast<IntegerAttr>().getInt();
-      spatialDims[i] = floor(spatialDims[i] / stride);
+      outSpatialDims[i] = floor(outSpatialDims[i] / stride);
     }
   }
 
   for (int i = 0; i < nDims; ++i)
-    spatialDims[i] += 1;
+    outSpatialDims[i] += 1;
 
-  dims.append(spatialDims.begin(), spatialDims.end());
+  dims.append(outSpatialDims.begin(), outSpatialDims.end());
 
   getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
 }
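
For readers who want to sanity-check the shape arithmetic by hand, the standalone sketch below (not part of the patch; plain standard C++, and the helper names convOutputDim and convOutputDimSame are illustrative only) evaluates the same formula the patch implements, out = floor((in - K + padBegin + padEnd) / stride) + 1, assuming SAME_UPPER/SAME_LOWER pad each spatial dimension by a total of K - 1.

// Standalone sketch of the convolution output-size formula; illustrative only.
#include <cstdint>
#include <iostream>

// out = floor((in - K + padBegin + padEnd) / stride) + 1
// (assumes the padded extent is at least K, so the numerator is non-negative
// and integer division floors).
int64_t convOutputDim(int64_t in, int64_t kernel, int64_t padBegin,
                      int64_t padEnd, int64_t stride) {
  return (in - kernel + padBegin + padEnd) / stride + 1;
}

// SAME_UPPER / SAME_LOWER: pad each spatial dimension by a total of K - 1,
// split between begin and end. Here the larger half is placed at the end
// (SAME_UPPER); SAME_LOWER would place it at the beginning, which does not
// change the resulting output size.
int64_t convOutputDimSame(int64_t in, int64_t kernel, int64_t stride) {
  int64_t totalPad = kernel - 1;
  return convOutputDim(in, kernel, totalPad / 2, totalPad - totalPad / 2,
                       stride);
}

int main() {
  // 32-wide input, 3-wide kernel.
  std::cout << convOutputDim(32, 3, 0, 0, 1) << "\n"; // no padding, stride 1: 30
  std::cout << convOutputDimSame(32, 3, 1) << "\n";   // SAME padding, stride 1: 32
  std::cout << convOutputDim(32, 3, 1, 1, 2) << "\n"; // pad 1+1, stride 2: 16
  return 0;
}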