From ab8e2f9a1bba1b7081e4b6755b89881029d8ad5a Mon Sep 17 00:00:00 2001
From: Doru Bercea
Date: Mon, 20 Jan 2020 11:16:27 -0500
Subject: [PATCH] Add verifier to check for required attributes.

---
 src/dialect/onnx/onnx.td          |  4 +-
 src/dialect/onnx/onnx_ops.cpp     | 88 +++++++++++++++++++++++++++++++
 src/dialect/onnx/onnxop.inc       |  9 ++++
 src/pass/shape_inference_pass.cpp |  3 +-
 4 files changed, 102 insertions(+), 2 deletions(-)

diff --git a/src/dialect/onnx/onnx.td b/src/dialect/onnx/onnx.td
index 29733f7..710f3af 100644
--- a/src/dialect/onnx/onnx.td
+++ b/src/dialect/onnx/onnx.td
@@ -104,7 +104,7 @@ def ONNXGemmNoBiasOp: ONNX_Op<"GemmNoBias",
 }
 
 def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
-  [NoSideEffect]> {
+  [NoSideEffect, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
   let summary = "ONNX Conv operation with no Bias operand.";
   let description = [{
     "The convolution operator consumes an input tensor and a filter, and"
@@ -112,6 +112,8 @@ def ONNXConvNoBiasOp:ONNX_Op<"ConvNoBias",
   }];
   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W);
   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
+
+  let verifier = [{ return ::verify(*this); }];
 }
 
 def ONNXMaxPoolSingleOutOp: ONNX_Op<"MaxPoolSingleOut",
diff --git a/src/dialect/onnx/onnx_ops.cpp b/src/dialect/onnx/onnx_ops.cpp
index 53e463d..44332b5 100644
--- a/src/dialect/onnx/onnx_ops.cpp
+++ b/src/dialect/onnx/onnx_ops.cpp
@@ -448,6 +448,94 @@ LogicalResult verify(ONNXTransposeOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+
+// Conv
+
+// For this operation, we define the attributes once in the original Conv
+// operation class. There is no need to redefine the attribute names for the
+// other classes based on Conv.
+void ONNXConvNoBiasOp::inferShapes() {
+  // Generic shape for data input X and weight tensor W:
+  // X: (N x C x D1 x D2 ... x Dn)
+  // W: (M x C/group x k1 x k2 x ... x kn)
+
+  // Cannot infer shape if no shape exists.
+  if (!getOperand(0).getType().isa<RankedTensorType>() ||
+      !getOperand(1).getType().isa<RankedTensorType>())
+    return;
+  auto dataTy = getOperand(0).getType().cast<RankedTensorType>();
+  auto weightTy = getOperand(1).getType().cast<RankedTensorType>();
+  auto dataShape = dataTy.getShape();
+  auto weightShape = weightTy.getShape();
+
+  if (dataShape.size() != weightShape.size()) {
+    emitError("ConvNoBias: weight rank not compatible with data rank.");
+    return;
+  }
+
+  // Group is a required attribute and should have a default value of 1.
+  auto groupAttr =
+      getAttrOfType<IntegerAttr>(ONNXConvOp::getGroupAttrName());
+  if (!groupAttr) {
+    emitError("ConvNoBias: group attribute missing.");
+    return;
+  }
+  int64_t group = groupAttr.getInt();
+
+  // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
+  if (dataShape[1] != (weightShape[1] * group)) {
+    emitError("ConvNoBias: channel dimension mismatch.");
+    return;
+  }
+
+  // Required attributes.
+  auto auto_pad = getAttrOfType<StringAttr>(
+      ONNXConvOp::getAutoPadAttrName());
+  auto pads = getAttrOfType<ArrayAttr>(
+      ONNXConvOp::getPadsAttrName());
+
+  SmallVector<int64_t, 4> dims;
+  // Insert batch size.
+  dims.emplace_back(dataShape[0]);
+  // Insert number of filters being applied (number of output channels).
+  dims.emplace_back(weightShape[0]);
+
+  // // Compute the spatial dimensions.
+  // SmallVector<int64_t, 4> spatialDims;
+  // // Number of spatial dimensions.
+  // int32_t nDims = dataShape.size() - 2;
+  // // Initialize dimensions based on the input and weight spatial dimensions.
+  // for (int i = 2; i < dataShape.size(); ++i)
+  //   spatialDims.emplace_back(dataShape[i] - weightShape[i]);
+  // // Add padding information.
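+  // // With unit strides and dilations (what this draft assumes), the
+  // // standard ONNX output size per spatial axis is
+  // //   O[i] = D[i] + pad_begin[i] + pad_end[i] - k[i] + 1;
+  // // the accumulation below contributes the two padding terms.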
+  // if (pads) {
+  //   for (int i = 0; i < nDims; ++i) {
+  //     // Padding for beginning of axis.
+  //     int32_t p = (pads.getValue()[i]).cast<IntegerAttr>().getInt();
+  //     spatialDims[i] += p;
+  //     // Padding for end of axis.
+  //     p = (pads.getValue()[i + nDims]).cast<IntegerAttr>().getInt();
+  //     spatialDims[i] += p;
+  //   }
+  // } else if (auto_pad) {
+  //   // Attribute pads has not been provided.
+  // }
+
+  getResult().setType(RankedTensorType::get(dims, dataTy.getElementType()));
+}
+
+LogicalResult verify(ONNXConvNoBiasOp op) {
+  auto module = op.getParentOfType<ModuleOp>();
+  if (!module)
+    return op.emitError("expected to belong to a module");
+
+  auto autoPadAttr = op.getAttrOfType<StringAttr>(
+      ONNXConvOp::getAutoPadAttrName());
+  if (!autoPadAttr)
+    return op.emitError("ONNXConvNoBiasOp: auto_pad attribute not specified.");
+
+  auto groupAttr =
+      op.getAttrOfType<IntegerAttr>(ONNXConvOp::getGroupAttrName());
+  if (!groupAttr)
+    return op.emitError("ONNXConvNoBiasOp: group attribute not specified.");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/src/dialect/onnx/onnxop.inc b/src/dialect/onnx/onnxop.inc
index e87a01a..3a54fa0 100644
--- a/src/dialect/onnx/onnxop.inc
+++ b/src/dialect/onnx/onnxop.inc
@@ -324,6 +324,15 @@ def ONNXConvOp:ONNX_Op<"Conv",
   }];
   let arguments = (ins AnyTypeOf<[AnyMemRef, AnyTensor]>:$X, AnyTypeOf<[AnyMemRef, AnyTensor]>:$W, AnyTypeOf<[AnyMemRef, AnyTensor]>:$B);
   let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>);
+
+  let extraClassDeclaration = [{
+    static StringRef getAutoPadAttrName() { return "auto_pad"; }
+    static StringRef getDilationsAttrName() { return "dilations"; }
+    static StringRef getGroupAttrName() { return "group"; }
+    static StringRef getKernelShapeAttrName() { return "kernel_shape"; }
+    static StringRef getPadsAttrName() { return "pads"; }
+    static StringRef getStridesAttrName() { return "strides"; }
+  }];
 }
 
 def ONNXConvIntegerOp:ONNX_Op<"ConvInteger",
diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp
index 3226f16..5239904 100644
--- a/src/pass/shape_inference_pass.cpp
+++ b/src/pass/shape_inference_pass.cpp
@@ -117,7 +117,8 @@ public:
         op->getName().getStringRef() != "onnx.GemmNoBias" &&
         op->getName().getStringRef() != "onnx.Reshape" &&
         op->getName().getStringRef() != "onnx.Transpose" &&
-        op->getName().getStringRef() != "onnx.Softmax")
+        op->getName().getStringRef() != "onnx.Softmax" &&
+        op->getName().getStringRef() != "onnx.ConvNoBias")
       return false;
     return llvm::any_of(op->getResultTypes(), [](Type result_type) {
       return !result_type.isa<RankedTensorType>();
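
For reference, the spatial-dimension logic left commented out in inferShapes()
reduces to the standard ONNX Conv output-size formula. Below is a minimal
standalone sketch of that arithmetic, not a drop-in for the patch: the helper
name computeConvOutputDims and its parameters are illustrative, and pads,
strides, and dilations are passed in explicitly rather than read from op
attributes.

#include <cstdint>
#include <vector>

// Compute the spatial output dims of a convolution with explicit pads.
// dataDims/kernelDims hold only the spatial axes (D1..Dn and k1..kn).
std::vector<int64_t> computeConvOutputDims(
    const std::vector<int64_t> &dataDims,
    const std::vector<int64_t> &kernelDims,
    const std::vector<int64_t> &padsBegin,
    const std::vector<int64_t> &padsEnd,
    const std::vector<int64_t> &strides,
    const std::vector<int64_t> &dilations) {
  std::vector<int64_t> out;
  for (size_t i = 0; i < dataDims.size(); ++i) {
    // Effective kernel extent once dilation is applied.
    int64_t effKernel = (kernelDims[i] - 1) * dilations[i] + 1;
    // Standard ONNX output-size formula; integer division acts as the
    // floor for the non-negative values involved here.
    out.push_back(
        (dataDims[i] + padsBegin[i] + padsEnd[i] - effKernel) / strides[i] + 1);
  }
  return out;
}

For example, a 32x32 input with a 3x3 kernel, pads of 1 on every side, and
unit strides and dilations maps to a 32x32 output, matching the expectation
that "same" padding preserves the spatial dims.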