diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp
index 6b6a29b..562d736 100644
--- a/src/Dialect/ONNX/ONNXOps.cpp
+++ b/src/Dialect/ONNX/ONNXOps.cpp
@@ -151,9 +151,8 @@ static void processConvStrideParam(T *op, Optional<ArrayAttr> kernelShape) {
 // Support function that computes default values for pads.
 //
 template <class T>
-static void processConvPadParam(T *op,
-    ArrayRef<int64_t> inputShape, Optional<ArrayAttr> kernelShape,
-    Optional<ArrayAttr> stridesOpt,
+static void processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
+    Optional<ArrayAttr> kernelShape, Optional<ArrayAttr> stridesOpt,
     Optional<ArrayAttr> dilationsOpt = llvm::None) {
   auto builder = mlir::Builder(op->getContext());
 
@@ -256,7 +255,7 @@ static void processConvTypeParams(T *op, Value inputOperand) {
   processConvDilationParam(op, kernelShape);
   auto dilationsOpt = op->dilations();
 
-  // Strides. 
+  // Strides.
   processConvStrideParam(op, kernelShape);
   auto stridesOpt = op->strides();
 
@@ -341,219 +340,271 @@ ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
 // Exp
 /// Infer the output shape of the ONNXExpOp. This method is required by the
 /// shape inference interface.
-void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }
+bool ONNXExpOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+  return true;
+}
 
 //===----------------------------------------------------------------------===//
 // Tanh
 /// Infer the output shape of the ONNXTanhOp. This method is required by the
 /// shape inference interface.
-void ONNXTanhOp::inferShapes() { getResult().setType(getOperand().getType()); }
+bool ONNXTanhOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+  return true;
+}
 
 //===----------------------------------------------------------------------===//
 // Sinh
 /// Infer the output shape of the ONNXSinhOp. This method is required by the
 /// shape inference interface.
-void ONNXSinhOp::inferShapes() { getResult().setType(getOperand().getType()); }
+bool ONNXSinhOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+  return true;
+}
 
 //===----------------------------------------------------------------------===//
 // Cosh
 /// Infer the output shape of the ONNXCoshOp. This method is required by the
 /// shape inference interface.
-void ONNXCoshOp::inferShapes() { getResult().setType(getOperand().getType()); }
+bool ONNXCoshOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+  return true;
+}
 
 //===----------------------------------------------------------------------===//
 // Cos
 /// Infer the output shape of the ONNXCosOp. This method is required by the
 /// shape inference interface.
-void ONNXCosOp::inferShapes() { getResult().setType(getOperand().getType()); }
+bool ONNXCosOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+  return true;
+}
 
 //===----------------------------------------------------------------------===//
 // Log
 /// Infer the output shape of the ONNXLogOp. This method is required by the
 /// shape inference interface.
-void ONNXLogOp::inferShapes() { getResult().setType(getOperand().getType()); }
+bool ONNXLogOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+  return true;
+}
 
 //===----------------------------------------------------------------------===//
 // HardSigmoid
 /// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
 /// the shape inference interface.
-void ONNXHardSigmoidOp::inferShapes() {
+bool ONNXHardSigmoidOp::inferShapes() {
   getResult().setType(getOperand().getType());
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Sigmoid
 /// Infer the output shape of the ONNXSigmoidOp. This method is required by the
 /// shape inference interface.
-void ONNXSigmoidOp::inferShapes() {
+bool ONNXSigmoidOp::inferShapes() {
   getResult().setType(getOperand().getType());
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Elu
 /// Infer the output shape of the ONNXEluOp. This method is required by the
 /// shape inference interface.
-void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }
+bool ONNXEluOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+  return true;
+}
 
 //===----------------------------------------------------------------------===//
 // Relu
 /// Infer the output shape of the ONNXReluOp. This method is required by the
 /// shape inference interface.
-void ONNXReluOp::inferShapes() { getResult().setType(getOperand().getType()); }
+bool ONNXReluOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+  return true;
+}
 
 //===----------------------------------------------------------------------===//
 // LeakyRelu
 /// Infer the output shape of the ONNXLeakyReluOp. This method is required by
 /// the shape inference interface.
-void ONNXLeakyReluOp::inferShapes() {
+bool ONNXLeakyReluOp::inferShapes() {
   getResult().setType(getOperand().getType());
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Selu
 /// Infer the output shape of the ONNXSeluOp. This method is required by
 /// the shape inference interface.
-void ONNXSeluOp::inferShapes() { getResult().setType(getOperand().getType()); }
+bool ONNXSeluOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+  return true;
+}
 
 //===----------------------------------------------------------------------===//
 // Reciprocal
 /// Infer the output shape of the ONNXReciprocalOp. This method is required by
 /// the shape inference interface.
-void ONNXReciprocalOp::inferShapes() {
+bool ONNXReciprocalOp::inferShapes() {
   getResult().setType(getOperand().getType());
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Softmax
 /// Infer the output shape of the ONNXSoftmaxOp. This method is required by
 /// the shape inference interface.
-void ONNXSoftmaxOp::inferShapes() {
+bool ONNXSoftmaxOp::inferShapes() {
   getResult().setType(getOperand().getType());
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Softplus
 /// Infer the output shape of the ONNXSoftplusOp. This method is required by
 /// the shape inference interface.
-void ONNXSoftplusOp::inferShapes() {
+bool ONNXSoftplusOp::inferShapes() {
   getResult().setType(getOperand().getType());
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Softsign
 /// Infer the output shape of the ONNXSoftsignOp. This method is required by
 /// the shape inference interface.
-void ONNXSoftsignOp::inferShapes() {
+bool ONNXSoftsignOp::inferShapes() {
   getResult().setType(getOperand().getType());
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Sqrt
 /// Infer the output shape of the ONNXSqrtOp. This method is required by
 /// the shape inference interface.
-void ONNXSqrtOp::inferShapes() { getResult().setType(getOperand().getType()); }
+bool ONNXSqrtOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+  return true;
+}
 
 //===----------------------------------------------------------------------===//
 // Sign
 /// Infer the output shape of the ONNXSignOp. This method is required by
 /// the shape inference interface.
-void ONNXSignOp::inferShapes() { getResult().setType(getOperand().getType()); }
+bool ONNXSignOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+  return true;
+}
 
 //===----------------------------------------------------------------------===//
 // Abs
 /// Infer the output shape of the ONNXAbsOp. This method is required by the
 /// shape inference interface.
-void ONNXAbsOp::inferShapes() { getResult().setType(getOperand().getType()); }
+bool ONNXAbsOp::inferShapes() {
+  getResult().setType(getOperand().getType());
+  return true;
+}
 
 //===----------------------------------------------------------------------===//
 // Add
 /// Infer the output shape of the ONNXAddOp. This method is required by the
 /// shape inference interface.
-void ONNXAddOp::inferShapes() {
+bool ONNXAddOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
-    return;
+      !getOperand(1).getType().isa<RankedTensorType>()) {
+    emitError("ONNXAddOp inferShapes failed");
+    return false;
+  }
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Mul
 /// Infer the output shape of the ONNXMulOp. This method is required by the
 /// shape inference interface.
-void ONNXMulOp::inferShapes() {
+bool ONNXMulOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
       !getOperand(1).getType().isa<RankedTensorType>())
-    return;
+    return false;
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Div
 /// Infer the output shape of the ONNXDivOp. This method is required by the
 /// shape inference interface.
-void ONNXDivOp::inferShapes() {
+bool ONNXDivOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
       !getOperand(1).getType().isa<RankedTensorType>())
-    return;
+    return false;
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Sub
 /// Infer the output shape of the ONNXSubOp. This method is required by the
 /// shape inference interface.
-void ONNXSubOp::inferShapes() {
+bool ONNXSubOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
       !getOperand(1).getType().isa<RankedTensorType>())
-    return;
+    return false;
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // And
 /// Infer the output shape of the ONNXAndOp. This method is required by the
 /// shape inference interface.
-void ONNXAndOp::inferShapes() {
+bool ONNXAndOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
       !getOperand(1).getType().isa<RankedTensorType>())
-    return;
+    return false;
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Or
 /// Infer the output shape of the ONNXOrOp. This method is required by the
 /// shape inference interface.
-void ONNXOrOp::inferShapes() {
+bool ONNXOrOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
       !getOperand(1).getType().isa<RankedTensorType>())
-    return;
+    return false;
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Xor
 /// Infer the output shape of the ONNXXorOp. This method is required by the
 /// shape inference interface.
-void ONNXXorOp::inferShapes() {
+bool ONNXXorOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
       !getOperand(1).getType().isa<RankedTensorType>())
-    return;
+    return false;
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
@@ -562,10 +613,10 @@ void ONNXXorOp::inferShapes() {
 // Sum
 /// Infer the output shape of the ONNXSumOp. This method is required by the
 /// shape inference interface.
-void ONNXSumOp::inferShapes() {
+bool ONNXSumOp::inferShapes() {
   for (int i = 0; i < getNumOperands(); ++i) {
     if (!getOperand(i).getType().cast<RankedTensorType>())
-      return;
+      return false;
   }
   Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
   for (int i = 1; i < getNumOperands(); ++i) {
@@ -573,16 +624,17 @@
     resultTy = getBroadcastedType(resultTy, nextTy);
   }
   getResult().setType(resultTy);
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Max
 /// Infer the output shape of the ONNXMaxOp. This method is required by the
 /// shape inference interface.
-void ONNXMaxOp::inferShapes() {
+bool ONNXMaxOp::inferShapes() {
   for (int i = 0; i < getNumOperands(); ++i) {
     if (!getOperand(i).getType().cast<RankedTensorType>())
-      return;
+      return false;
   }
   Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
   for (int i = 1; i < getNumOperands(); ++i) {
@@ -590,16 +642,17 @@
     resultTy = getBroadcastedType(resultTy, nextTy);
   }
   getResult().setType(resultTy);
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Min
 /// Infer the output shape of the ONNXMinOp. This method is required by the
 /// shape inference interface.
-void ONNXMinOp::inferShapes() {
+bool ONNXMinOp::inferShapes() {
   for (int i = 0; i < getNumOperands(); ++i) {
     if (!getOperand(i).getType().cast<RankedTensorType>())
-      return;
+      return false;
   }
   Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
   for (int i = 1; i < getNumOperands(); ++i) {
@@ -607,25 +660,27 @@
     resultTy = getBroadcastedType(resultTy, nextTy);
   }
   getResult().setType(resultTy);
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Identity
 /// Infer the output shape of the ONNXIdentityOp. This method is required by the
 /// shape inference interface.
-void ONNXIdentityOp::inferShapes() {
+bool ONNXIdentityOp::inferShapes() {
   getResult().setType(getOperand().getType());
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // MatMul
 
-void ONNXMatMulOp::inferShapes() {
+bool ONNXMatMulOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!A().getType().isa<RankedTensorType>() ||
       !B().getType().isa<RankedTensorType>())
-    return;
+    return false;
 
   auto lhsTy = A().getType().cast<RankedTensorType>();
   auto rhsTy = B().getType().cast<RankedTensorType>();
@@ -752,19 +807,20 @@
   }
 
   getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Gemm
 
-void ONNXGemmOp::inferShapes() {
+bool ONNXGemmOp::inferShapes() {
   bool hasBias = !C().getType().isa<NoneType>();
   // Cannot infer shape if no shape exists.
   if (!A().getType().isa<RankedTensorType>() ||
       !B().getType().isa<RankedTensorType>() ||
       (hasBias && !C().getType().isa<RankedTensorType>()))
-    return;
+    return false;
 
   auto lhsTy = A().getType().cast<RankedTensorType>();
   auto rhsTy = B().getType().cast<RankedTensorType>();
@@ -796,17 +852,18 @@
   dims.emplace_back(M);
   dims.emplace_back(N);
   getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
+  return true;
 }
 
 /// BatchNormalizationTestMode
-void ONNXBatchNormalizationTestModeOp::inferShapes() {
+bool ONNXBatchNormalizationTestModeOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!X().getType().isa<RankedTensorType>() ||
       !scale().getType().isa<RankedTensorType>() ||
       !B().getType().isa<RankedTensorType>() ||
       !mean().getType().isa<RankedTensorType>() ||
       !var().getType().isa<RankedTensorType>())
-    return;
+    return false;
 
   auto inputTensorTy = X().getType().cast<RankedTensorType>();
   auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
@@ -845,6 +902,7 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {
 
   // The output tensor of the same shape as the input.
   getResult().setType(X().getType());
+  return true;
 }
 
 // TODO:
@@ -855,7 +913,7 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {
 
 // Reshape
 
-void ONNXReshapeOp::inferShapes() {
+bool ONNXReshapeOp::inferShapes() {
   // Cannot infer shape if no shape tensor is specified.
   if (!shape().getType().isa<RankedTensorType>())
     emitError("Shape tensor not ranked");
@@ -875,7 +933,7 @@
 
   // Compute total number of elements.
   int64_t totalInputSize = 1;
-  for(auto inputDim : inputTensorTy.getShape())
+  for (auto inputDim : inputTensorTy.getShape())
     totalInputSize *= inputDim;
 
   // Check if second argument of ReshapeOp is a constant.
@@ -891,7 +949,7 @@
 
     // Get dims from valueAttribute.
     auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
-    for (int i=0; i<outputRank; ++i)
+    for (int i = 0; i < outputRank; ++i)
       dims[i] = (*valueIt++).cast<IntegerAttr>().getInt();
 
     if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
@@ -900,7 +958,7 @@
     int64_t numberOfDynamicInputs = 0;
     int64_t totalKnownDimsSize = 1;
    int64_t dynamicValueIndex = -1;
-    for (int i=0; i<outputRank; ++i) {
+    for (int i = 0; i < outputRank; ++i) {
      // Set output dimension.
      if (dims[i] == 0)
        dims[i] = inputTensorTy.getShape()[i];
@@ -918,15 +976,16 @@ void ONNXReshapeOp::inferShapes() {
   getResult().setType(
       RankedTensorType::get(dims, inputTensorTy.getElementType()));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 
 // Transpose
 
-void ONNXTransposeOp::inferShapes() {
+bool ONNXTransposeOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!data().getType().isa<RankedTensorType>())
-    return;
+    return false;
 
   // Naive transposition which handles the default case of
   // reversing the shape of the tensor (similar to numpy.transpose).
@@ -951,62 +1010,67 @@ void ONNXTransposeOp::inferShapes() {
   }
 
   getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // ReduceMax
 
-void ONNXReduceMaxOp::inferShapes() {
+bool ONNXReduceMaxOp::inferShapes() {
   if (!getOperand().getType().isa<RankedTensorType>()) {
     emitError("Shape tensor not ranked");
-    return;
+    return false;
   }
 
   auto operandTy = getOperand().getType().cast<RankedTensorType>();
   getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // ReduceMin
 
-void ONNXReduceMinOp::inferShapes() {
+bool ONNXReduceMinOp::inferShapes() {
   if (!getOperand().getType().isa<RankedTensorType>()) {
     emitError("Shape tensor not ranked");
-    return;
+    return false;
   }
 
   auto operandTy = getOperand().getType().cast<RankedTensorType>();
   getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // ReduceProd
 
-void ONNXReduceProdOp::inferShapes() {
+bool ONNXReduceProdOp::inferShapes() {
   if (!getOperand().getType().isa<RankedTensorType>()) {
     emitError("Shape tensor not ranked");
-    return;
+    return false;
   }
 
   auto operandTy = getOperand().getType().cast<RankedTensorType>();
   getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // ReduceSum
 
-void ONNXReduceSumOp::inferShapes() {
+bool ONNXReduceSumOp::inferShapes() {
   if (!getOperand().getType().isa<RankedTensorType>()) {
     emitError("Shape tensor not ranked");
-    return;
+    return false;
   }
 
   auto operandTy = getOperand().getType().cast<RankedTensorType>();
   getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1022,7 +1086,7 @@ void ONNXReduceSumOp::inferShapes() {
 // - kernelShape: inferred from weight matrix if not defined by user;
 // - pads: set to proper value, 0 if not defined by user.
 
-void ONNXConvOp::inferShapes() {
+bool ONNXConvOp::inferShapes() {
   // Generic shape for data input X, weight tensor W, and optional bias B
   // X: (N x C x D1 x D2 ... x Dn)
   // W: (M x C/group x k1 x k2 x ... x kn)
@@ -1034,7 +1098,7 @@
   if (!X().getType().isa<RankedTensorType>() ||
       !W().getType().isa<RankedTensorType>() ||
       (hasBias && !B().getType().isa<RankedTensorType>()))
-    return;
+    return false;
 
   auto xTy = X().getType().cast<RankedTensorType>();
   auto xShape = xTy.getShape();
@@ -1043,12 +1107,16 @@
   auto builder = mlir::Builder(this->getContext());
 
   // Lowest supported convolution is a one dimensional convolution.
-  if (xShape.size() < 3)
+  if (xShape.size() < 3) {
     emitError("Data input shape must be at least (NxCxD1)");
+    return false;
+  }
 
   // Check that shape of weight and data have same length.
-  if (xShape.size() != weightShape.size())
+  if (xShape.size() != weightShape.size()) {
     emitError("Weight size not compatible with data size");
+    return false;
+  }
 
   // Group is a required attribute and should have default value of 1.
   int64_t group = ONNXConvOp::group().getSExtValue();
@@ -1059,19 +1127,25 @@
 
   // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
   if (xShape[1] != -1 && weightShape[1] != -1 &&
-      xShape[1] != (weightShape[1] * group))
+      xShape[1] != (weightShape[1] * group)) {
     emitError("Channel dimension mismatch");
+    return false;
+  }
 
   // Check the size of bias.
   if (hasBias) {
     auto bTx = B().getType().cast<RankedTensorType>();
     auto bShape = bTx.getShape();
-    if (bShape.size() != 1)
+    if (bShape.size() != 1) {
       emitError("bias should be one dimensional");
-    if (bShape[0] != weightShape[0])
+      return false;
+    }
+    if (bShape[0] != weightShape[0]) {
       emitError("bias should have same dimensions as weight's first dimension");
+      return false;
+    }
   }
-
+
   // Note: the value of the group attribut only impacts the way the
   // computation is carried out and not the actual output size.
 
@@ -1083,12 +1157,16 @@
   // argument.
   auto kernelShape = kernel_shape();
   if (kernelShape.hasValue()) {
-    if (ArrayAttrSize(kernelShape) != spatialRank)
+    if (ArrayAttrSize(kernelShape) != spatialRank) {
       emitError("kernel_shape length incompatible with spatial dimensions");
+      return false;
+    }
     // Have the right number of values, check them.
     for (int i = 0; i < spatialRank; ++i)
-      if (ArrayAttrIntVal(kernelShape, i) < 1)
+      if (ArrayAttrIntVal(kernelShape, i) < 1) {
         emitError("bad kernel_shape value");
+        return false;
+      }
   } else {
     // Deduce shape from weight input.
     SmallVector<int64_t, 2> defaultVals;
@@ -1119,6 +1197,7 @@
       &outputDims, xShape, kernelShape, padsOpt, stridesOpt, dilationsOpt);
 
   getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1129,10 +1208,10 @@ void ONNXConvOp::inferShapes() {
 // - strides: set to 1 if not defined by user;
 // - pads: set to proper value, 0 if not defined by user.
 
-void ONNXAveragePoolOp::inferShapes() {
+bool ONNXAveragePoolOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!X().getType().isa<RankedTensorType>())
-    return;
+    return false;
 
   // Get shape of input.
   auto xTy = X().getType().cast<RankedTensorType>();
@@ -1163,6 +1242,7 @@
       llvm::None, ceilMode);
 
   getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1173,10 +1253,10 @@ void ONNXAveragePoolOp::inferShapes() {
 // - dilations, strides: set to 1 if not defined by user;
 // - pads: set to proper value, 0 if not defined by user.
 
-void ONNXMaxPoolSingleOutOp::inferShapes() {
+bool ONNXMaxPoolSingleOutOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!X().getType().isa<RankedTensorType>())
-    return;
+    return false;
 
   // Get shape of input.
   auto xTy = X().getType().cast<RankedTensorType>();
 
@@ -1211,6 +1291,7 @@
       dilationsOpt, ceilMode);
 
   getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1234,7 +1315,7 @@ static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
       // Have to non-negative constant
       if (p1 < 0 || p2 < 0)
         return (Type)NULL;
-      if (outputShape[i] != -1) 
+      if (outputShape[i] != -1)
         outputShape[i] += p1 + p2;
     }
 
@@ -1246,24 +1327,26 @@ static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
 // PadConstantPad
 
-void ONNXPadConstantPadOp::inferShapes() {
+bool ONNXPadConstantPadOp::inferShapes() {
   auto outputType = padShapeInferenceHelper(data(), pads());
   if (outputType) {
     getResult().setType(outputType);
+    return true;
   }
-  return;
+  return false;
 }
 
 //===----------------------------------------------------------------------===//
 
 // PadConstantValuePad
 
-void ONNXPadConstantValuePadOp::inferShapes() {
+bool ONNXPadConstantValuePadOp::inferShapes() {
   auto outputType = padShapeInferenceHelper(data(), pads());
   if (outputType) {
     getResult().setType(outputType);
-  }
-  return;
+    return true;
+  }
+  return false;
 }
 
 void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
@@ -1280,9 +1363,9 @@ void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
 
 // Unsqueeze
 
-void ONNXUnsqueezeOp::inferShapes() {
+bool ONNXUnsqueezeOp::inferShapes() {
   if (!data().getType().isa<RankedTensorType>())
-    return;
+    return false;
 
   auto operandTy = data().getType().cast<RankedTensorType>();
   int inRank = operandTy.getRank();
@@ -1299,11 +1382,14 @@
       assert(axis >= -outRank && axis <= outRank - 1);
       if (std::find(axes.begin(), axes.end(), axis) == axes.end())
         axes.emplace_back(axis);
-      else
+      else {
         emitError("Duplicated axes");
+        return false;
+      }
     }
   } else {
     emitError("Axes attribute is required");
+    return false;
   }
 
   SmallVector<int64_t, 4> dims;
@@ -1315,12 +1401,13 @@
     }
   }
   getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
 // Constant
 
-void ONNXConstantOp::inferShapes() {
+bool ONNXConstantOp::inferShapes() {
   if ((sparse_value().hasValue() && value().hasValue()) ||
       (!sparse_value().hasValue() && !value().hasValue()))
     emitError("Require exactly one of the two attributes, either value or "
@@ -1332,6 +1419,7 @@ void ONNXConstantOp::inferShapes() {
   else
     valAttr = valueAttr().cast<DenseElementsAttr>();
   getResult().setType(valAttr.getType());
+  return true;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/src/Interface/ShapeInferenceInterface.td b/src/Interface/ShapeInferenceInterface.td
index ca991a4..0d49090 100644
--- a/src/Interface/ShapeInferenceInterface.td
+++ b/src/Interface/ShapeInferenceInterface.td
@@ -25,7 +25,7 @@ def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
 
   let methods = [
     InterfaceMethod<"Infer and set the output shape for the current operation.",
-                    "void", "inferShapes">
+                    "bool", "inferShapes">
   ];
 }
diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp
index e16fe89..ab05931 100644
--- a/src/Transform/ONNX/ShapeInferencePass.cpp
+++ b/src/Transform/ONNX/ShapeInferencePass.cpp
@@ -35,7 +35,11 @@ public:
     f.walk([&](mlir::Operation *op) {
       if (returnsDynamicShape(op)) {
         if (auto shape_op =
                 dyn_cast<ShapeInference>(op)) {
-          shape_op.inferShapes();
+          if (!shape_op.inferShapes()) {
+            op->emitError("unable to infer shape of operation without shape "
+                          "inference interface");
+            return signalPassFailure();
+          }
         } else {
           op->emitError("unable to infer shape of operation without shape "
                         "inference interface");
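
// ---------------------------------------------------------------------------
// Reviewer note (not part of the patch): a minimal, self-contained C++ sketch
// of the control-flow contract this change introduces. The names below
// (Op, UnaryOp, AddLikeOp, runShapeInference) are hypothetical stand-ins for
// the MLIR op classes and the ShapeInferencePass walk; only the bool return
// convention mirrors the patch. inferShapes() now reports failure to its
// caller instead of silently returning, and the driver aborts the pass on
// the first op whose shape cannot be inferred.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Op {
  std::string name;
  virtual ~Op() = default;
  // Returns true on success; on failure the op has already emitted an error.
  virtual bool inferShapes() = 0;
  void emitError(const std::string &msg) const {
    std::cerr << name << ": " << msg << "\n";
  }
};

// Ops like Exp/Relu forward the operand type and cannot fail.
struct UnaryOp : Op {
  bool inferShapes() override { return true; }
};

// Ops like Add require ranked operands; unranked inputs now surface an error.
struct AddLikeOp : Op {
  bool rankedOperands = true;
  bool inferShapes() override {
    if (!rankedOperands) {
      emitError("inferShapes failed");
      return false; // was a silent `return;` before the patch
    }
    return true;
  }
};

// Mirrors the patched ShapeInferencePass: stop at the first failing op.
bool runShapeInference(const std::vector<std::unique_ptr<Op>> &ops) {
  for (const auto &op : ops)
    if (!op->inferShapes())
      return false; // stands in for signalPassFailure()
  return true;
}

int main() {
  std::vector<std::unique_ptr<Op>> ops;
  auto relu = std::make_unique<UnaryOp>();
  relu->name = "onnx.Relu";
  auto add = std::make_unique<AddLikeOp>();
  add->name = "onnx.Add";
  add->rankedOperands = false; // force a failure to exercise the error path
  ops.push_back(std::move(relu));
  ops.push_back(std::move(add));
  std::cout << (runShapeInference(ops) ? "inference ok" : "pass failed")
            << "\n";
  return 0;
}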