diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp index cb9407f..2b65c15 100644 --- a/src/dialect/onnx/onnx_ops.cpp +++ b/src/dialect/onnx/onnx_ops.cpp @@ -469,52 +469,102 @@ void ONNXConvNoBiasOp::inferShapes() { auto dataShape = dataTy.getShape(); auto weightShape = weightTy.getShape(); + // Check that shape of weight and data have same length. if (dataShape.size() != weightShape.size()) - emitError("ConvNoBias: weight size not compatible with data size."); + emitError("Weight size not compatible with data size."); + // Required attribute auto_pad defaults to NOTSET. + auto autoPad = getAttrOfType( + ONNXConvOp::getAutoPadAttrName()).getValue(); // Group is a required attribute and should have default value of 1. int64_t group = getAttrOfType( ONNXConvOp::getGroupAttrName()).getInt(); - if (!group) - emitError("ConvNoBias: group attribute missing."); - // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds. if (dataShape[1] != (weightShape[1] * group)) - emitError("ConvNoBias: channel dimension mismatch."); - - // Required attributes. - auto auto_pad = getAttrOfType( - ONNXConvOp::getAutoPadAttrName()); - auto pads = getAttrOfType( - ONNXConvOp::getPadsAttrName()); + emitError("Channel dimension mismatch."); + // First two output dimensions consist of the number of batches and the + // number of kernels being applied. + // SmallVector dims; // Insert batch size. dims.emplace_back(dataShape[0]); // Insert number of filters being applied (number of output channels). dims.emplace_back(weightShape[0]); - // // Compute the spatial dimensions. - // SmallVector spatialDims; - // // Number of spatial dimensions. - // int32_t nDims = dataTy.size() - 2; - // // Initialize dimenions based on the input and weight spatial dimensions. - // for (int i = 2; i < dataTy.size(); ++i) - // spatialDims.emplace_back(dataTy[i] - weightTy[i]); - // // Add padding information. - // if () { - // for (int i = 0; i < nDims; ++i) { - // // Padding for beginning of axis. - // int32_t p = (pads.getValue()[i]).cast().getInt(); - // spatialDims[i] += p; - // // Padding for end of axis. - // p = (pads.getValue()[i + nDims]).cast().getInt(); - // spatialDims[i] += p; - // } - // } else if () { - // // Attribute pads has not been provided. - // } + // Spatial dimensions are computed using the formula: + // + // dim = (inputDim - kernelDim + startPadding + endPadding) / stride + 1 + // + SmallVector spatialDims; + // Number of spatial dimensions. + int32_t nDims = dataShape.size() - 2; + // Initialize dimenions based on the input spatial dimensions. + for (int i = 2; i < dataShape.size(); ++i) + spatialDims.emplace_back(dataShape[i]); + + // Use kernel_shape attribute if present otherwise use size from weight + // argument. + if (auto kernel_shape = getAttrOfType( + ONNXConvOp::getKernelShapeAttrName())) { + if (kernel_shape.getValue().size() != nDims) + emitError("kernel_shape length incompatible with spatial dimensions."); + for (int i = 0; i < nDims; ++i) { + int64_t kernelDim = + (kernel_shape.getValue()[i]).cast().getInt(); + spatialDims[i] -= kernelDim; + } + } else { + for (int i = 0; i < nDims; ++i) + spatialDims[i] -= weightShape[i + 2]; + } + + // Add padding information. + if (autoPad == "NOTSET") { + // Use pads to to determine the padding. If attribute is not + // present then pads is considered to be all zeros (no padding). + if (auto pads = getAttrOfType( + ONNXConvOp::getPadsAttrName())) { + // pads consists of two entries for each spatial axis. + if (pads.getValue().size() != 2 * nDims) + emitError("pads size is not twice the spatial size."); + + for (int i = 0; i < nDims; ++i) { + // Padding for beginning of axis. + int32_t p = (pads.getValue()[i]).cast().getInt(); + spatialDims[i] += p; + // Padding for end of axis. + p = (pads.getValue()[i + nDims]).cast().getInt(); + spatialDims[i] += p; + } + } + } else if (autoPad == "VALID") { + // TODO + } else if (autoPad == "SAME_UPPER") { + // TODO + } else if (autoPad == "SAME_LOWER") { + // TODO + } else { + emitError("Unexpected attribute value for auto_pad."); + } + + // Strides + if (auto strides = getAttrOfType( + ONNXConvOp::getStridesAttrName())) { + if (strides.getValue().size() != nDims) + emitError("strides length incompatible with spatial dimensions."); + for (int i = 0; i < nDims; ++i) { + int64_t stride = + (strides.getValue()[i]).cast().getInt(); + spatialDims[i] = floor(spatialDims[i] / stride); + } + } + + for (int i = 0; i < nDims; ++i) + spatialDims[i] += 1; + + dims.append(spatialDims.begin(), spatialDims.end()); getResult().setType(RankedTensorType::get(dims, dataTy.getElementType())); } @@ -526,12 +576,16 @@ LogicalResult verify(ONNXConvNoBiasOp op) { auto autoPadAttr = op.getAttrOfType( ONNXConvOp::getAutoPadAttrName()); if (!autoPadAttr) - op.emitError("ONNXConvNoBiasOp: auto_pad attribute not specified."); + op.emitError("auto_pad attribute not specified."); + if (autoPadAttr.getValue() != "NOTSET") + if (auto pads = op.getAttrOfType( + ONNXConvOp::getPadsAttrName())) + op.emitError("auto_pad and pads are both set."); auto groupAttr = op.getAttrOfType(ONNXConvOp::getGroupAttrName()); if (!groupAttr) - op.emitError("ONNXConvNoBiasOp: group attribute not specified."); + op.emitError("group attribute not specified."); return success(); }