From 4f8fd9d1bf27760622770c10188b9e67222fb927 Mon Sep 17 00:00:00 2001
From: Alexandre Eichenberger
Date: Tue, 26 May 2020 22:09:28 -0400
Subject: [PATCH] Error fix3 (#145)

* removed warnings for missing return and dangling else

* fixed errors, made sure to return false in all shape inference failures

* shape inference now uses LogicalResult as return value

* format fixed

* format error

Co-authored-by: Tian Jin
---
 .../ONNXToKrnl/Math/Elementwise.cpp       |   1 +
 src/Conversion/ONNXToKrnl/Math/MatMul.cpp |   4 +-
 src/Dialect/Krnl/KrnlOps.hpp              |   1 +
 src/Dialect/ONNX/ONNXOps.cpp              | 660 ++++++++----------
 src/Interface/ShapeInferenceInterface.td  |   2 +-
 src/Transform/ONNX/ShapeInferencePass.cpp |   2 +-
 6 files changed, 293 insertions(+), 377 deletions(-)

diff --git a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
index 076bc25..ffa63ba 100644
--- a/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
+++ b/src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -490,6 +490,7 @@ Value emitScalarOpFor<ONNXNegOp>(ConversionPatternRewriter &rewriter,
     return rewriter.create<SubFOp>(loc, zero, operand); // 0 - X = -X
   } else {
     emitError(loc, "unsupported element type");
+    return nullptr;
   }
 }
 
diff --git a/src/Conversion/ONNXToKrnl/Math/MatMul.cpp b/src/Conversion/ONNXToKrnl/Math/MatMul.cpp
index 4a639e7..286221b 100644
--- a/src/Conversion/ONNXToKrnl/Math/MatMul.cpp
+++ b/src/Conversion/ONNXToKrnl/Math/MatMul.cpp
@@ -258,12 +258,12 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
       for (auto arg : loopBatchIVs)
         loopBatchKNIVs.emplace_back(arg);
       loopBatchKNIVs.emplace_back(loopKIVs[0]);
-      if (BShape.size() >= 2)
+      if (BShape.size() >= 2) {
         if (AShape.size() >= 2)
           loopBatchKNIVs.emplace_back(loopMNIVs[1]);
         else
           loopBatchKNIVs.emplace_back(loopMNIVs[0]);
-
+      }
       // Matmul computation
       auto loadedA = rewriter.create<LoadOp>(loc, A, loopBatchMKIVs);
       auto loadedB = rewriter.create<LoadOp>(loc, B, loopBatchKNIVs);

diff --git a/src/Dialect/Krnl/KrnlOps.hpp b/src/Dialect/Krnl/KrnlOps.hpp
index 351c8ca..b287883 100644
--- a/src/Dialect/Krnl/KrnlOps.hpp
+++ b/src/Dialect/Krnl/KrnlOps.hpp
@@ -32,6 +32,7 @@ public:
       return LoopType::get(parser.getBuilder().getContext());
 
     parser.emitError(parser.getCurrentLocation(), "Unknown type");
+    return nullptr;
   }
 
   /// Print a type registered to this dialect.

diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp
index 6cedbae..c6efad0 100644
--- a/src/Dialect/ONNX/ONNXOps.cpp
+++ b/src/Dialect/ONNX/ONNXOps.cpp
@@ -124,18 +124,22 @@ RankedTensorType getReductionOutputType(
 // Support function that computes default values for dilations.
 //
 template <class T>
-static void processConvDilationParam(T *op, Optional<ArrayAttr> kernelShape) {
+static LogicalResult processConvDilationParam(
+    T *op, Optional<ArrayAttr> kernelShape) {
   auto builder = mlir::Builder(op->getContext());
   auto kernelRank = ArrayAttrSize(kernelShape);
   auto dilationsOpt = op->dilations();
   if (dilationsOpt.hasValue()) {
-    if (ArrayAttrSize(dilationsOpt) != kernelRank)
-      op->emitError("dialation rank is not the same as the spatial rank");
+    if (ArrayAttrSize(dilationsOpt) != kernelRank) {
+      return op->emitError(
+          "dilation rank is not the same as the spatial rank");
+    }
     // Test values to be greater than 0.
     for (int i = 0; i < kernelRank; ++i) {
-      if (ArrayAttrIntVal(dilationsOpt, i) < 1)
-        op->emitError("dialation value must be nonzero positive");
+      if (ArrayAttrIntVal(dilationsOpt, i) < 1) {
+        return op->emitError("dilation value must be nonzero positive");
+      }
     }
   } else {
     // Default dilatation is needed, all dimensions init with 1.
@@ -144,24 +148,26 @@ static void processConvDilationParam(T *op, Optional<ArrayAttr> kernelShape) {
     ArrayRef<int64_t> defaultRefs(defaultVals);
     op->dilationsAttr(builder.getI64ArrayAttr(defaultRefs));
   }
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Support function that computes default values for strides.
 //
 template <class T>
-static void processConvStrideParam(T *op, Optional<ArrayAttr> kernelShape) {
+static LogicalResult processConvStrideParam(
+    T *op, Optional<ArrayAttr> kernelShape) {
   auto builder = mlir::Builder(op->getContext());
   auto kernelRank = ArrayAttrSize(kernelShape);
   auto stridesOpt = op->strides();
   if (stridesOpt.hasValue()) {
     if (ArrayAttrSize(stridesOpt) != kernelRank)
-      op->emitError("strides rank is not the same as the spatial rank");
+      return op->emitError("strides rank is not the same as the spatial rank");
     // Check values to be greater than 0.
     for (int i = 0; i < kernelRank; ++i) {
       if (ArrayAttrIntVal(stridesOpt, i) < 1)
-        op->emitError("strides value must be nonzero positive");
+        return op->emitError("strides value must be nonzero positive");
     }
   } else {
     // Default stride is needed, all dimensions init with 1.
@@ -170,13 +176,14 @@ static void processConvStrideParam(T *op, Optional<ArrayAttr> kernelShape) {
     ArrayRef<int64_t> defaultRefs(defaultVals);
     op->stridesAttr(builder.getI64ArrayAttr(defaultRefs));
   }
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Support function that computes default values for pads.
 //
 template <class T>
-static void processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
+static LogicalResult processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
     Optional<ArrayAttr> kernelShape, Optional<ArrayAttr> stridesOpt,
     Optional<ArrayAttr> dilationsOpt = llvm::None) {
   auto builder = mlir::Builder(op->getContext());
@@ -196,12 +203,14 @@ static void processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
   if (padsOpt.hasValue()) {
     // Only option where pads are not updated. Pads consists of two entries
     // for each spatial axis.
-    if (ArrayAttrSize(padsOpt) != 2 * kernelRank)
-      op->emitError("pads rank is not twice the spatial rank");
+    if (ArrayAttrSize(padsOpt) != 2 * kernelRank) {
+      return op->emitError("pads rank is not twice the spatial rank");
+    }
     // Check values, pads cannot be negative.
     for (int i = 0; i < 2 * kernelRank; ++i) {
-      if (ArrayAttrIntVal(padsOpt, i) < 0)
-        op->emitError("pads value must be nonnegative");
+      if (ArrayAttrIntVal(padsOpt, i) < 0) {
+        return op->emitError("pads value must be nonnegative");
+      }
     }
   } else {
     // We have notset with no pads, they are assumed to be all zero.
@@ -251,7 +260,7 @@ static void processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
     // No pad, default value was set to zero, we are all set.
     updatedPad = true;
   } else {
-    op->emitError("auto_pad of unknown / unsupported value");
+    return op->emitError("auto_pad of unknown / unsupported value");
   }
   // Set pads values in attributes, if it is needed.
   if (updatedPad) {
@@ -260,13 +269,14 @@ static void processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
   }
   // In all cases now, the acutal pad values are found in the pads attribute.
   op->auto_padAttr(builder.getStringAttr("NOTSET"));
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Support function that computes default values for dilations, strides, and
 // pads.
 template <class T>
-static void processConvTypeParams(T *op, Value inputOperand) {
+static LogicalResult processConvTypeParams(T *op, Value inputOperand) {
   auto builder = mlir::Builder(op->getContext());
 
   // 1) Get shape of input.
@@ -277,15 +287,20 @@ static void processConvTypeParams(T *op, Value inputOperand) {
   auto kernelShape = op->kernel_shape();
 
   // Dilation.
-  processConvDilationParam(op, kernelShape);
+  LogicalResult res = processConvDilationParam(op, kernelShape);
+  if (failed(res))
+    return res;
   auto dilationsOpt = op->dilations();
 
   // Strides.
-  processConvStrideParam(op, kernelShape);
+  res = processConvStrideParam(op, kernelShape);
+  if (failed(res))
+    return res;
   auto stridesOpt = op->strides();
 
   // Pads.
-  processConvPadParam(op, inputShape, kernelShape, stridesOpt, dilationsOpt);
+  return processConvPadParam(
+      op, inputShape, kernelShape, stridesOpt, dilationsOpt);
 }
 
 //===----------------------------------------------------------------------===//
@@ -321,15 +336,16 @@ static void insertConvSpatialDim(SmallVector<int64_t, 4> *outputDims,
 //===----------------------------------------------------------------------===//
 // Support function that infers shape for RNN operations.
 template <class T>
-static bool RNNShapeInference(T *op) {
+static LogicalResult RNNShapeInference(T *op) {
   Value X = op->X();
   Value W = op->W();
   Value R = op->R();
 
   if (!X.getType().isa<RankedTensorType>() ||
       !W.getType().isa<RankedTensorType>() ||
-      !R.getType().isa<RankedTensorType>())
-    return false;
+      !R.getType().isa<RankedTensorType>()) {
+    return op->emitError("Input tensor not ranked");
+  }
 
   auto xTy = X.getType().cast<RankedTensorType>();
   auto elementType = xTy.getElementType();
@@ -342,16 +358,13 @@ static bool RNNShapeInference(T *op) {
   auto rShape = R.getType().cast<RankedTensorType>().getShape();
 
   if (xShape.size() != 3) {
-    op->emitError("The first input tensor must have rank 3");
-    return false;
+    return op->emitError("The first input tensor must have rank 3");
   }
   if (wShape.size() != 3) {
-    op->emitError("The second input tensor must have rank 3");
-    return false;
+    return op->emitError("The second input tensor must have rank 3");
   }
   if (rShape.size() != 3) {
-    op->emitError("The third input tensor must have rank 3");
-    return false;
+    return op->emitError("The third input tensor must have rank 3");
  }
 
   // Get sequence length, batch size and input size.
@@ -387,9 +400,9 @@ static bool RNNShapeInference(T *op) {
   else
     numDirection = -1;
   if (numDirection == -1) {
-    op->emitError("direction attribute muse be one of the strings: forward, "
-                  "reverse, and bidirectional");
-    return false;
+    return op->emitError(
+        "direction attribute must be one of the strings: forward, "
+        "reverse, and bidirectional");
   }
 
   // Set result types.
@@ -421,7 +434,7 @@ static bool RNNShapeInference(T *op) {
       op->getResults()[2].setType(ycTy);
     }
   }
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -465,283 +478,269 @@ ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
 // Exp
 /// Infer the output shape of the ONNXExpOp. This method is required by the
 /// shape inference interface.
-bool ONNXExpOp::inferShapes() {
+LogicalResult ONNXExpOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Tanh
 /// Infer the output shape of the ONNXTanhOp. This method is required by the
 /// shape inference interface.
-bool ONNXTanhOp::inferShapes() {
+LogicalResult ONNXTanhOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Sinh
 /// Infer the output shape of the ONNXSinhOp. This method is required by the
 /// shape inference interface.
-bool ONNXSinhOp::inferShapes() {
+LogicalResult ONNXSinhOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Cosh
 /// Infer the output shape of the ONNXCoshOp. This method is required by the
 /// shape inference interface.
-bool ONNXCoshOp::inferShapes() {
+LogicalResult ONNXCoshOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Cos
 /// Infer the output shape of the ONNXCosOp. This method is required by the
 /// shape inference interface.
-bool ONNXCosOp::inferShapes() {
+LogicalResult ONNXCosOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Log
 /// Infer the output shape of the ONNXLogOp. This method is required by the
 /// shape inference interface.
-bool ONNXLogOp::inferShapes() {
+LogicalResult ONNXLogOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // HardSigmoid
 /// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
 /// the shape inference interface.
-bool ONNXHardSigmoidOp::inferShapes() {
+LogicalResult ONNXHardSigmoidOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Sigmoid
 /// Infer the output shape of the ONNXSigmoidOp. This method is required by the
 /// shape inference interface.
-bool ONNXSigmoidOp::inferShapes() {
+LogicalResult ONNXSigmoidOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Elu
 /// Infer the output shape of the ONNXEluOp. This method is required by the
 /// shape inference interface.
-bool ONNXEluOp::inferShapes() {
+LogicalResult ONNXEluOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Relu
 /// Infer the output shape of the ONNXReluOp. This method is required by the
 /// shape inference interface.
-bool ONNXReluOp::inferShapes() {
+LogicalResult ONNXReluOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // LeakyRelu
 /// Infer the output shape of the ONNXLeakyReluOp. This method is required by
 /// the shape inference interface.
-bool ONNXLeakyReluOp::inferShapes() {
+LogicalResult ONNXLeakyReluOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Selu
 /// Infer the output shape of the ONNXSeluOp. This method is required by
 /// the shape inference interface.
-bool ONNXSeluOp::inferShapes() {
+LogicalResult ONNXSeluOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Reciprocal
 /// Infer the output shape of the ONNXReciprocalOp. This method is required by
 /// the shape inference interface.
-bool ONNXReciprocalOp::inferShapes() {
+LogicalResult ONNXReciprocalOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Softmax
 /// Infer the output shape of the ONNXSoftmaxOp. This method is required by
 /// the shape inference interface.
-bool ONNXSoftmaxOp::inferShapes() {
+LogicalResult ONNXSoftmaxOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Softplus
 /// Infer the output shape of the ONNXSoftplusOp. This method is required by
 /// the shape inference interface.
-bool ONNXSoftplusOp::inferShapes() {
+LogicalResult ONNXSoftplusOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Softsign
 /// Infer the output shape of the ONNXSoftsignOp. This method is required by
 /// the shape inference interface.
-bool ONNXSoftsignOp::inferShapes() {
+LogicalResult ONNXSoftsignOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Sqrt
 /// Infer the output shape of the ONNXSqrtOp. This method is required by
 /// the shape inference interface.
-bool ONNXSqrtOp::inferShapes() {
+LogicalResult ONNXSqrtOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Sign
 /// Infer the output shape of the ONNXSignOp. This method is required by
 /// the shape inference interface.
-bool ONNXSignOp::inferShapes() {
+LogicalResult ONNXSignOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Abs
 /// Infer the output shape of the ONNXAbsOp. This method is required by the
 /// shape inference interface.
-bool ONNXAbsOp::inferShapes() {
+LogicalResult ONNXAbsOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Add
 /// Infer the output shape of the ONNXAddOp. This method is required by the
 /// shape inference interface.
-bool ONNXAddOp::inferShapes() {
+LogicalResult ONNXAddOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>()) {
-    emitError("Input tensor(s) not ranked");
-    return false;
-  }
+      !getOperand(1).getType().isa<RankedTensorType>())
+    return emitError("Input tensor(s) not ranked");
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Mul
 /// Infer the output shape of the ONNXMulOp. This method is required by the
 /// shape inference interface.
-bool ONNXMulOp::inferShapes() {
+LogicalResult ONNXMulOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>()) {
-    emitError("Input tensor(s) not ranked");
-    return false;
-  }
+      !getOperand(1).getType().isa<RankedTensorType>())
+    return emitError("Input tensor(s) not ranked");
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Div
 /// Infer the output shape of the ONNXDivOp. This method is required by the
 /// shape inference interface.
-bool ONNXDivOp::inferShapes() {
+LogicalResult ONNXDivOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>()) {
-    emitError("Input tensor(s) not ranked");
-    return false;
-  }
+      !getOperand(1).getType().isa<RankedTensorType>())
    return emitError("Input tensor(s) not ranked");
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Sub
 /// Infer the output shape of the ONNXSubOp. This method is required by the
 /// shape inference interface.
-bool ONNXSubOp::inferShapes() {
+LogicalResult ONNXSubOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>()) {
-    emitError("Input tensor(s) not ranked");
-    return false;
-  }
+      !getOperand(1).getType().isa<RankedTensorType>())
+    return emitError("Input tensor(s) not ranked");
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // And
 /// Infer the output shape of the ONNXAndOp. This method is required by the
 /// shape inference interface.
-bool ONNXAndOp::inferShapes() {
+LogicalResult ONNXAndOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>()) {
-    emitError("Input tensor(s) not ranked");
-    return false;
-  }
+      !getOperand(1).getType().isa<RankedTensorType>())
+    return emitError("Input tensor(s) not ranked");
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Or
 /// Infer the output shape of the ONNXOrOp. This method is required by the
 /// shape inference interface.
-bool ONNXOrOp::inferShapes() {
+LogicalResult ONNXOrOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>()) {
-    emitError("Input tensor(s) not ranked");
-    return false;
-  }
+      !getOperand(1).getType().isa<RankedTensorType>())
    return emitError("Input tensor(s) not ranked");
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Xor
 /// Infer the output shape of the ONNXXorOp. This method is required by the
 /// shape inference interface.
-bool ONNXXorOp::inferShapes() {
+LogicalResult ONNXXorOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>()) {
-    emitError("Input tensor(s) not ranked");
-    return false;
-  }
+      !getOperand(1).getType().isa<RankedTensorType>())
+    return emitError("Input tensor(s) not ranked");
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -750,12 +749,10 @@ bool ONNXXorOp::inferShapes() {
 
 // Sum
 /// Infer the output shape of the ONNXSumOp. This method is required by the
 /// shape inference interface.
-bool ONNXSumOp::inferShapes() {
+LogicalResult ONNXSumOp::inferShapes() {
   for (int i = 0; i < getNumOperands(); ++i) {
-    if (!getOperand(i).getType().cast<RankedTensorType>()) {
-      emitError("Input tensor(s) not ranked");
-      return false;
-    }
+    if (!getOperand(i).getType().cast<RankedTensorType>())
+      return emitError("Input tensor(s) not ranked");
   }
   Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
   for (int i = 1; i < getNumOperands(); ++i) {
@@ -763,19 +760,17 @@ bool ONNXSumOp::inferShapes() {
     resultTy = getBroadcastedType(resultTy, nextTy);
   }
   getResult().setType(resultTy);
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Max
 /// Infer the output shape of the ONNXMaxOp. This method is required by the
 /// shape inference interface.
-bool ONNXMaxOp::inferShapes() {
+LogicalResult ONNXMaxOp::inferShapes() {
   for (int i = 0; i < getNumOperands(); ++i) {
-    if (!getOperand(i).getType().cast<RankedTensorType>()) {
-      emitError("Input tensor(s) not ranked");
-      return false;
-    }
+    if (!getOperand(i).getType().cast<RankedTensorType>())
+      return emitError("Input tensor(s) not ranked");
   }
   Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
   for (int i = 1; i < getNumOperands(); ++i) {
@@ -783,19 +778,17 @@ bool ONNXMaxOp::inferShapes() {
     resultTy = getBroadcastedType(resultTy, nextTy);
   }
   getResult().setType(resultTy);
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Min
 /// Infer the output shape of the ONNXMinOp. This method is required by the
 /// shape inference interface.
-bool ONNXMinOp::inferShapes() {
+LogicalResult ONNXMinOp::inferShapes() {
   for (int i = 0; i < getNumOperands(); ++i) {
-    if (!getOperand(i).getType().cast<RankedTensorType>()) {
-      emitError("Input tensor(s) not ranked");
-      return false;
-    }
+    if (!getOperand(i).getType().cast<RankedTensorType>())
+      return emitError("Input tensor(s) not ranked");
   }
   Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
   for (int i = 1; i < getNumOperands(); ++i) {
@@ -803,38 +796,36 @@ bool ONNXMinOp::inferShapes() {
     resultTy = getBroadcastedType(resultTy, nextTy);
   }
   getResult().setType(resultTy);
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Neg
 /// Infer the output shape of the ONNXNegOp. This method is required by the
 /// shape inference interface.
-bool ONNXNegOp::inferShapes() {
+LogicalResult ONNXNegOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Identity
 /// Infer the output shape of the ONNXIdentityOp. This method is required by the
 /// shape inference interface.
-bool ONNXIdentityOp::inferShapes() {
+LogicalResult ONNXIdentityOp::inferShapes() {
   getResult().setType(getOperand().getType());
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // MatMul
 
-bool ONNXMatMulOp::inferShapes() {
+LogicalResult ONNXMatMulOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!A().getType().isa<RankedTensorType>() ||
-      !B().getType().isa<RankedTensorType>()) {
-    emitError("Input tensor(s) not ranked");
-    return false;
-  }
+      !B().getType().isa<RankedTensorType>())
+    return emitError("Input tensor(s) not ranked");
 
   auto lhsTy = A().getType().cast<RankedTensorType>();
   auto rhsTy = B().getType().cast<RankedTensorType>();
@@ -845,14 +836,14 @@ bool ONNXMatMulOp::inferShapes() {
 
   if (lhsShape.size() < 1 && rhsShape.size() < 1) {
     // Multiplication by scalars is not allowed.
-    emitError("Multiplication by scalar arguments not allowed");
+    return emitError("Multiplication by scalar arguments not allowed");
   } else if (lhsShape.size() == 1 && rhsShape.size() == 1) {
     // Special case when both arrays are 1-dimensional and according to
     // numpy rules the types need to be extended to 1xN and Nx1. Helper sizes
     // need to be removed after the multiplication but cannot be removed if all
     // sizes are 1.
     if (lhsShape[0] != -1 && rhsShape[0] != -1 && lhsShape[0] != rhsShape[0])
-      emitError("Attempt to multiply incompatible matrices");
+      return emitError("Attempt to multiply incompatible matrices");
     dims.emplace_back(1);
   } else if (lhsShape.size() == 1 && rhsShape.size() >= 2) {
     // If the first argument is 1-D, it is promoted to a matrix by prepending a
@@ -867,8 +858,7 @@ bool ONNXMatMulOp::inferShapes() {
     unsigned rhsRank = rhsShape.size();
     if (lhsShape[0] != -1 && rhsShape[rhsRank - 2] != -1 &&
         lhsShape[0] != rhsShape[rhsRank - 2])
-      emitError("Attempt to multiply incompatible matrices");
-
+      return emitError("Attempt to multiply incompatible matrices");
     for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
       dims.emplace_back(rhsShape[i]);
     dims.emplace_back(rhsShape[rhsRank - 1]);
@@ -885,8 +875,7 @@ bool ONNXMatMulOp::inferShapes() {
     unsigned lhsRank = lhsShape.size();
     if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
         lhsShape[lhsRank - 1] != rhsShape[0])
-      emitError("Attempt to multiply incompatible matrices");
-
+      return emitError("Attempt to multiply incompatible matrices");
     for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
       dims.emplace_back(lhsShape[i]);
     dims.emplace_back(lhsShape[lhsRank - 2]);
@@ -899,8 +888,7 @@ bool ONNXMatMulOp::inferShapes() {
     unsigned lhsRank = lhsShape.size();
     if (lhsShape[lhsRank - 1] != -1 && rhsShape[0] != -1 &&
         lhsShape[lhsRank - 1] != rhsShape[0])
-      emitError("Attempt to multiply incompatible matrices");
-
+      return emitError("Attempt to multiply incompatible matrices");
     for (decltype(lhsRank) i = 0; i < lhsRank - 1; ++i)
       dims.emplace_back(lhsShape[i]);
     dims.emplace_back(rhsShape[1]);
@@ -913,8 +901,7 @@ bool ONNXMatMulOp::inferShapes() {
     unsigned rhsRank = rhsShape.size();
     if (lhsShape[1] != -1 && rhsShape[rhsRank - 2] != -1 &&
         lhsShape[1] != rhsShape[rhsRank - 2])
-      emitError("Attempt to multiply incompatible matrices");
-
+      return emitError("Attempt to multiply incompatible matrices");
     for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
       dims.emplace_back(rhsShape[i]);
     dims.emplace_back(lhsShape[0]);
@@ -929,8 +916,7 @@ bool ONNXMatMulOp::inferShapes() {
     unsigned rhsRank = rhsShape.size();
     if (lhsShape[lhsRank - 1] != -1 && rhsShape[rhsRank - 2] != -1 &&
         lhsShape[lhsRank - 1] != rhsShape[rhsRank - 2])
-      emitError("Attempt to multiply incompatible matrices");
-
+      return emitError("Attempt to multiply incompatible matrices");
     // Check and perform broadcasting for the shapes.
     SmallVector<int64_t, 2> lhsBcastShape;
     for (decltype(lhsRank) i = 0; i < lhsRank - 2; ++i)
       lhsBcastShape.emplace_back(lhsShape[i]);
     SmallVector<int64_t, 2> rhsBcastShape;
     for (decltype(rhsRank) i = 0; i < rhsRank - 2; ++i)
       rhsBcastShape.emplace_back(rhsShape[i]);
     if (!getBroadcastedShape(lhsBcastShape, rhsBcastShape, dims))
@@ -939,8 +925,7 @@ bool ONNXMatMulOp::inferShapes() {
-      emitError("Broadcasted dimensions are incompatible");
-
+      return emitError("Broadcasted dimensions are incompatible");
     dims.emplace_back(lhsShape[lhsRank - 2]);
     dims.emplace_back(rhsShape[rhsRank - 1]);
   } else {
@@ -954,29 +939,26 @@ bool ONNXMatMulOp::inferShapes() {
 
     // Check legality of matrix multiplication.
     if (lhsDim != -1 && rhsDim != -1 && lhsDim != rhsDim)
-      emitError("Attempt to multiply incompatible matrices");
-
+      return emitError("Attempt to multiply incompatible matrices");
     if (rhsShape.size() > 1)
       dims.emplace_back(rhsShape[1]);
   }
 
   getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Gemm
 
-bool ONNXGemmOp::inferShapes() {
+LogicalResult ONNXGemmOp::inferShapes() {
   bool hasBias = !C().getType().isa<NoneType>();
   // Cannot infer shape if no shape exists.
   if (!A().getType().isa<RankedTensorType>() ||
       !B().getType().isa<RankedTensorType>() ||
-      (hasBias && !C().getType().isa<RankedTensorType>())) {
-    emitError("Input tensor(s) not ranked");
-    return false;
-  }
+      (hasBias && !C().getType().isa<RankedTensorType>()))
+    return emitError("Input tensor(s) not ranked");
 
   auto lhsTy = A().getType().cast<RankedTensorType>();
   auto rhsTy = B().getType().cast<RankedTensorType>();
@@ -986,9 +968,8 @@ bool ONNXGemmOp::inferShapes() {
   N = (transB() == 0) ? rhsTy.getShape()[1] : rhsTy.getShape()[0];
   K_B = (transB() == 0) ? rhsTy.getShape()[0] : rhsTy.getShape()[1];
 
-  if ((K_A != -1) && (K_B != -1) && (K_A != K_B)) {
-    emitError("Tensor shapes mismatched");
-  }
+  if ((K_A != -1) && (K_B != -1) && (K_A != K_B))
+    return emitError("Tensor shapes mismatched");
 
   if (hasBias) {
     // Check whether bias is unidirectional broadcasting or not.
@@ -999,29 +980,26 @@ bool ONNXGemmOp::inferShapes() {
         (rank >= 1 && shape[rank - 1] != -1 && N != -1 &&
             N != shape[rank - 1] && shape[rank - 1] != 1) ||
         (rank == 2 && shape[rank - 2] != -1 && M != -1 &&
-            M != shape[rank - 2] && shape[rank - 2] != 1)) {
-      emitError("Bias shape mismatched");
-    }
+            M != shape[rank - 2] && shape[rank - 2] != 1))
+      return emitError("Bias shape mismatched");
   }
 
   SmallVector<int64_t, 2> dims;
   dims.emplace_back(M);
   dims.emplace_back(N);
   getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
-  return true;
+  return success();
 }
 
 /// BatchNormalizationTestMode
-bool ONNXBatchNormalizationTestModeOp::inferShapes() {
+LogicalResult ONNXBatchNormalizationTestModeOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!X().getType().isa<RankedTensorType>() ||
       !scale().getType().isa<RankedTensorType>() ||
       !B().getType().isa<RankedTensorType>() ||
       !mean().getType().isa<RankedTensorType>() ||
-      !var().getType().isa<RankedTensorType>()) {
-    emitError("Input tensor(s) not ranked");
-    return false;
-  }
+      !var().getType().isa<RankedTensorType>())
+    return emitError("Input tensor(s) not ranked");
 
   auto inputTensorTy = X().getType().cast<RankedTensorType>();
   auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
@@ -1039,7 +1017,7 @@ bool ONNXBatchNormalizationTestModeOp::inferShapes() {
   } else if (inputTensorTy.getShape().size() > 2) {
     c = (inputTensorTy.getShape()[1] != -1) ? inputTensorTy.getShape()[1] : -1;
   } else {
-    emitError("Wrong rank for the input");
+    return emitError("Wrong rank for the input");
   }
 
   if (c != -1) {
@@ -1049,18 +1027,18 @@ bool ONNXBatchNormalizationTestModeOp::inferShapes() {
     auto v = varianceTensorTy.getShape();
 
     if ((s.size() != 1) || (s[0] != -1 && s[0] != c))
-      emitError("Wrong rank for the scale");
+      return emitError("Wrong rank for the scale");
     if ((b.size() != 1) || (b[0] != -1 && b[0] != c))
-      emitError("Wrong rank for the bias");
+      return emitError("Wrong rank for the bias");
     if ((m.size() != 1) || (m[0] != -1 && m[0] != c))
-      emitError("Wrong rank for the mean");
+      return emitError("Wrong rank for the mean");
     if ((v.size() != 1) || (v[0] != -1 && v[0] != c))
-      emitError("Wrong rank for the variance");
+      return emitError("Wrong rank for the variance");
   }
 
   // The output tensor of the same shape as the input.
   getResult().setType(X().getType());
-  return true;
+  return success();
 }
 
 // TODO:
@@ -1071,31 +1049,25 @@ bool ONNXBatchNormalizationTestModeOp::inferShapes() {
 
 // Reshape
 
-bool ONNXReshapeOp::inferShapes() {
+LogicalResult ONNXReshapeOp::inferShapes() {
   // Cannot infer shape if no shape tensor is specified.
-  if (!data().getType().isa<RankedTensorType>()) {
-    emitError("Input data tensor not ranked");
-    return false;
-  }
+  if (!data().getType().isa<RankedTensorType>())
+    return emitError("Input data tensor not ranked");
 
-  if (!shape().getType().isa<RankedTensorType>()) {
-    emitError("Shape tensor not ranked");
-    return false;
-  }
+  if (!shape().getType().isa<RankedTensorType>())
+    return emitError("Shape tensor not ranked");
 
   auto inputTensorTy = data().getType().cast<RankedTensorType>();
   auto shapeTensorTy = shape().getType().cast<RankedTensorType>();
 
   // Only rank 1 shape tensors are supported.
   if (shapeTensorTy.getShape().size() != 1)
-    emitError("Shape tensor must have rank one");
-
+    return emitError("Shape tensor must have rank one");
   int64_t outputRank = shapeTensorTy.getShape()[0];
 
   // Shape tensor must have constant shape.
   if (outputRank < 0)
-    emitError("Shape tensor must have constant shape");
-
+    return emitError("Shape tensor must have constant shape");
   // Compute total number of elements.
   int64_t totalInputSize = 1;
   for (auto inputDim : inputTensorTy.getShape())
@@ -1110,16 +1082,14 @@ bool ONNXReshapeOp::inferShapes() {
         constantOp.valueAttr().dyn_cast<DenseElementsAttr>();
 
     if (!valueAttribute)
-      emitError("DenseElementsAttr expected");
-
+      return emitError("DenseElementsAttr expected");
     // Get dims from valueAttribute.
     auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
     for (int i = 0; i < outputRank; ++i)
       dims[i] = (*valueIt++).cast<IntegerAttr>().getInt();
 
     if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
-      emitError("Constant value must have same rank as output");
-
+      return emitError("Constant value must have same rank as output");
     int64_t numberOfDynamicInputs = 0;
     int64_t totalKnownDimsSize = 1;
     int64_t dynamicValueIndex = -1;
@@ -1147,19 +1117,17 @@ bool ONNXReshapeOp::inferShapes() {
 
   getResult().setType(
       RankedTensorType::get(dims, inputTensorTy.getElementType()));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Transpose
 
-bool ONNXTransposeOp::inferShapes() {
+LogicalResult ONNXTransposeOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!data().getType().isa<RankedTensorType>()) {
-    emitError("Input tensor not ranked");
-    return false;
-  }
+  if (!data().getType().isa<RankedTensorType>())
+    return emitError("Input tensor not ranked");
 
   // Naive transposition which handles the default case of
   // reversing the shape of the tensor (similar to numpy.transpose).
@@ -1177,67 +1145,59 @@ bool ONNXTransposeOp::inferShapes() {
   }
 
   getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // ReduceMax
 
-bool ONNXReduceMaxOp::inferShapes() {
-  if (!getOperand().getType().isa<RankedTensorType>()) {
-    emitError("Input tensor not ranked");
-    return false;
-  }
+LogicalResult ONNXReduceMaxOp::inferShapes() {
+  if (!getOperand().getType().isa<RankedTensorType>())
+    return emitError("Input tensor not ranked");
 
   auto operandTy = getOperand().getType().cast<RankedTensorType>();
   getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // ReduceMin
 
-bool ONNXReduceMinOp::inferShapes() {
-  if (!getOperand().getType().isa<RankedTensorType>()) {
-    emitError("Input tensor not ranked");
-    return false;
-  }
+LogicalResult ONNXReduceMinOp::inferShapes() {
+  if (!getOperand().getType().isa<RankedTensorType>())
+    return emitError("Input tensor not ranked");
 
   auto operandTy = getOperand().getType().cast<RankedTensorType>();
   getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // ReduceProd
 
-bool ONNXReduceProdOp::inferShapes() {
-  if (!getOperand().getType().isa<RankedTensorType>()) {
-    emitError("Input tensor not ranked");
-    return false;
-  }
+LogicalResult ONNXReduceProdOp::inferShapes() {
+  if (!getOperand().getType().isa<RankedTensorType>())
+    return emitError("Input tensor not ranked");
 
   auto operandTy = getOperand().getType().cast<RankedTensorType>();
   getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // ReduceSum
 
-bool ONNXReduceSumOp::inferShapes() {
-  if (!getOperand().getType().isa<RankedTensorType>()) {
-    emitError("Input tensor not ranked");
-    return false;
-  }
+LogicalResult ONNXReduceSumOp::inferShapes() {
+  if (!getOperand().getType().isa<RankedTensorType>())
+    return emitError("Input tensor not ranked");
 
   auto operandTy = getOperand().getType().cast<RankedTensorType>();
   getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1253,7 +1213,7 @@ bool ONNXReduceSumOp::inferShapes() {
 // - kernelShape: inferred from weight matrix if not defined by user;
 // - pads: set to proper value, 0 if not defined by user.
 
-bool ONNXConvOp::inferShapes() {
+LogicalResult ONNXConvOp::inferShapes() {
   // Generic shape for data input X, weight tensor W, and optional bias B
   // X: (N x C x D1 x D2 ... x Dn)
   // W: (M x C/group x k1 x k2 x ... x kn)
   // B: (M) Optional
 
   bool hasBias = !B().getType().isa<NoneType>();
 
   // Cannot infer shape if no shape exists.
   if (!X().getType().isa<RankedTensorType>() ||
       !W().getType().isa<RankedTensorType>() ||
-      (hasBias && !B().getType().isa<RankedTensorType>())) {
-    emitError("Input tensor not ranked");
-    return false;
-  }
+      (hasBias && !B().getType().isa<RankedTensorType>()))
+    return emitError("Input tensor not ranked");
 
   auto xTy = X().getType().cast<RankedTensorType>();
   auto xShape = xTy.getShape();
   auto weightTy = W().getType().cast<RankedTensorType>();
   auto weightShape = weightTy.getShape();
   auto builder = mlir::Builder(this->getContext());
 
   // Lowest supported convolution is a one dimensional convolution.
-  if (xShape.size() < 3) {
-    emitError("Data input shape must be at least (NxCxD1)");
-    return false;
-  }
+  if (xShape.size() < 3)
+    return emitError("Data input shape must be at least (NxCxD1)");
 
   // Check that shape of weight and data have same length.
-  if (xShape.size() != weightShape.size()) {
-    emitError("Weight size not compatible with data size");
-    return false;
-  }
+  if (xShape.size() != weightShape.size())
+    return emitError("Weight size not compatible with data size");
 
   // Group is a required attribute and should have default value of 1.
   int64_t group = ONNXConvOp::group().getSExtValue();
@@ -1296,23 +1250,18 @@ bool ONNXConvOp::inferShapes() {
 
   // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
   if (xShape[1] != -1 && weightShape[1] != -1 &&
-      xShape[1] != (weightShape[1] * group)) {
-    emitError("Channel dimension mismatch");
-    return false;
-  }
+      xShape[1] != (weightShape[1] * group))
+    return emitError("Channel dimension mismatch");
 
   // Check the size of bias.
   if (hasBias) {
     auto bTx = B().getType().cast<RankedTensorType>();
     auto bShape = bTx.getShape();
-    if (bShape.size() != 1) {
-      emitError("bias should be one dimensional");
-      return false;
-    }
-    if (bShape[0] != weightShape[0]) {
-      emitError("bias should have same dimensions as weight's first dimension");
-      return false;
-    }
+    if (bShape.size() != 1)
+      return emitError("bias should be one dimensional");
+    if (bShape[0] != weightShape[0])
+      return emitError("bias should have same dimensions "
+                       "as weight's first dimension");
   }
 
   // Note: the value of the group attribut only impacts the way the
@@ -1326,16 +1275,13 @@ bool ONNXConvOp::inferShapes() {
   // argument.
   auto kernelShape = kernel_shape();
   if (kernelShape.hasValue()) {
-    if (ArrayAttrSize(kernelShape) != spatialRank) {
-      emitError("kernel_shape length incompatible with spatial dimensions");
-      return false;
-    }
+    if (ArrayAttrSize(kernelShape) != spatialRank)
+      return emitError(
+          "kernel_shape length incompatible with spatial dimensions");
     // Have the right number of values, check them.
     for (int i = 0; i < spatialRank; ++i)
-      if (ArrayAttrIntVal(kernelShape, i) < 1) {
-        emitError("bad kernel_shape value");
-        return false;
-      }
+      if (ArrayAttrIntVal(kernelShape, i) < 1)
+        return emitError("bad kernel_shape value");
   } else {
     // Deduce shape from weight input.
     SmallVector<int64_t, 2> defaultVals;
@@ -1366,7 +1312,7 @@ bool ONNXConvOp::inferShapes() {
       stridesOpt, dilationsOpt);
 
   getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1377,12 +1323,10 @@ bool ONNXConvOp::inferShapes() {
 // - strides: set to 1 if not defined by user;
 // - pads: set to proper value, 0 if not defined by user.
 
-bool ONNXAveragePoolOp::inferShapes() {
+LogicalResult ONNXAveragePoolOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!X().getType().isa<RankedTensorType>()) {
-    emitError("Input tensor not ranked");
-    return false;
-  }
+  if (!X().getType().isa<RankedTensorType>())
+    return emitError("Input tensor not ranked");
 
   auto builder = mlir::Builder(getContext());
 
@@ -1393,17 +1337,22 @@ bool ONNXAveragePoolOp::inferShapes() {
 
   // Kernel shape.
   auto kernelShape = kernel_shape();
   if (!kernelShape)
-    emitError(
+    return emitError(
         "kernel_shape is a mandatory attribute for which there is no default");
 
   // Ceil mode.
   auto ceilMode = ceil_mode().getSExtValue();
 
   // Process strides and pads.
-  processConvStrideParam<ONNXAveragePoolOp>(this, kernelShape);
+  LogicalResult res =
+      processConvStrideParam<ONNXAveragePoolOp>(this, kernelShape);
+  if (failed(res))
+    return res;
   auto stridesOpt = strides();
-  processConvPadParam<ONNXAveragePoolOp>(
+  res = processConvPadParam<ONNXAveragePoolOp>(
       this, xShape, kernelShape, stridesOpt, llvm::None);
+  if (failed(res))
+    return res;
   auto padsOpt = pads();
 
   SmallVector<int64_t, 4> outputDims;
@@ -1415,7 +1364,7 @@ bool ONNXAveragePoolOp::inferShapes() {
       stridesOpt, llvm::None, ceilMode);
 
   getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1426,12 +1375,10 @@ bool ONNXAveragePoolOp::inferShapes() {
 // - dilations, strides: set to 1 if not defined by user;
 // - pads: set to proper value, 0 if not defined by user.
 
-bool ONNXMaxPoolSingleOutOp::inferShapes() {
+LogicalResult ONNXMaxPoolSingleOutOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!X().getType().isa<RankedTensorType>()) {
-    emitError("Input tensor not ranked");
-    return false;
-  }
+  if (!X().getType().isa<RankedTensorType>())
+    return emitError("Input tensor not ranked");
 
   auto builder = mlir::Builder(getContext());
 
@@ -1442,13 +1389,13 @@ bool ONNXMaxPoolSingleOutOp::inferShapes() {
 
   // Kernel shape.
   auto kernelShape = kernel_shape();
   if (!kernelShape)
-    emitError(
+    return emitError(
        "kernel_shape is a mandatory attribute for which there is no default");
 
   // Storage order.
   auto storageOrder = storage_order().getSExtValue();
   if (storageOrder != 0)
-    emitError("column major storage order not supported at this time");
+    return emitError("column major storage order not supported at this time");
 
   // Process strides, dilations, and pads.
   processConvTypeParams<>(this, X());
@@ -1468,26 +1415,21 @@ bool ONNXMaxPoolSingleOutOp::inferShapes() {
       stridesOpt, dilationsOpt, ceilMode);
 
   getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 
-bool ONNXPadOp::inferShapes() {
+LogicalResult ONNXPadOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!data().getType().isa<RankedTensorType>()) {
-    emitError("Pad: unknown input shape");
-    return false;
-  }
+  if (!data().getType().isa<RankedTensorType>())
+    return emitError("Pad: unknown input shape");
 
   // Cannot infer if the pads is not constant
   DenseElementsAttr padsAttributes =
       getAttr("pads").dyn_cast_or_null<DenseElementsAttr>();
-
-  if (!padsAttributes) {
-    emitError("Pad: unknown pads");
-    return false;
-  }
+  if (!padsAttributes)
+    return emitError("Pad: unknown pads");
 
   auto dataTy = data().getType().cast<RankedTensorType>();
   auto dataShape = dataTy.getShape();
@@ -1507,17 +1449,15 @@ bool ONNXPadOp::inferShapes() {
     int64_t p1 = pads[i];
     int64_t p2 = pads[i + dataRank];
     // Have to non-negative constant
-    if (p1 < 0 || p2 < 0) {
-      emitError("padding value can not be negative");
-      return false;
-    }
+    if (p1 < 0 || p2 < 0)
+      return emitError("padding value can not be negative");
     if (outputShape[i] != -1)
       outputShape[i] += p1 + p2;
   }
 
   auto outputType =
       RankedTensorType::get(outputShape, dataTy.getElementType());
   getResult().setType(outputType);
-  return true;
+  return success();
 }
 
 static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
@@ -1551,26 +1491,24 @@ static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
 
 // PadConstantPad
 
-bool ONNXPadConstantPadOp::inferShapes() {
+LogicalResult ONNXPadConstantPadOp::inferShapes() {
   auto outputType = padShapeInferenceHelper(data(), pads());
-  if (outputType) {
-    getResult().setType(outputType);
-    return true;
-  }
-  return false;
+  if (!outputType)
+    return emitError("missing output");
+  getResult().setType(outputType);
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // PadConstantValuePad
 
-bool ONNXPadConstantValuePadOp::inferShapes() {
+LogicalResult ONNXPadConstantValuePadOp::inferShapes() {
   auto outputType = padShapeInferenceHelper(data(), pads());
-  if (outputType) {
-    getResult().setType(outputType);
-    return true;
-  }
-  return false;
+  if (!outputType)
+    return emitError("missing output");
+  getResult().setType(outputType);
+  return success();
 }
 
 void ONNXPadConstantValuePadOp::build(OpBuilder &builder, OperationState &state,
@@ -1587,11 +1525,9 @@ void ONNXPadConstantValuePadOp::build(OpBuilder &builder, OperationState &state,
 
 // Unsqueeze
 
-bool ONNXUnsqueezeOp::inferShapes() {
-  if (!data().getType().isa<RankedTensorType>()) {
-    emitError("Input tensor not ranked");
-    return false;
-  }
+LogicalResult ONNXUnsqueezeOp::inferShapes() {
+  if (!data().getType().isa<RankedTensorType>())
+    return emitError("Input tensor not ranked");
 
   auto operandTy = data().getType().cast<RankedTensorType>();
   int inRank = operandTy.getRank();
@@ -1608,15 +1544,11 @@ bool ONNXUnsqueezeOp::inferShapes() {
       assert(axis >= -outRank && axis <= outRank - 1);
       if (std::find(axes.begin(), axes.end(), axis) == axes.end())
         axes.emplace_back(axis);
-      else {
-        emitError("Duplicated axes");
-        return false;
-      }
+      else
+        return emitError("Duplicated axes");
     }
-  } else {
-    emitError("Axes attribute is required");
-    return false;
-  }
+  } else
+    return emitError("Axes attribute is required");
 
   SmallVector<int64_t, 4> dims;
   for (int i = 0, j = 0; i < outRank || j < inRank; ++i) {
@@ -1627,36 +1559,33 @@ bool ONNXUnsqueezeOp::inferShapes() {
     }
   }
   getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // Constant
 
-bool ONNXConstantOp::inferShapes() {
+LogicalResult ONNXConstantOp::inferShapes() {
   if ((sparse_value().hasValue() && value().hasValue()) ||
       (!sparse_value().hasValue() && !value().hasValue()))
-    emitError("Require exactly one of the two attributes, either value or "
-              "sparse_value");
-
+    return emitError("Require exactly one of the two attributes, "
+                     "either value or sparse_value");
   ElementsAttr valAttr;
   if (sparse_value().hasValue())
     valAttr = sparse_valueAttr().cast<SparseElementsAttr>();
   else
     valAttr = valueAttr().cast<DenseElementsAttr>();
   getResult().setType(valAttr.getType());
-  return true;
+  return success();
 }
 
 // Concat
 
-bool ONNXConcatOp::inferShapes() {
+LogicalResult ONNXConcatOp::inferShapes() {
   int inputNum = getNumOperands();
   for (int i = 0; i < inputNum; ++i) {
-    if (!getOperand(i).getType().cast<RankedTensorType>()) {
-      emitError("Input tensor(s) not ranked");
-      return false;
-    }
+    if (!getOperand(i).getType().cast<RankedTensorType>())
+      return emitError("Input tensor(s) not ranked");
   }
   // Checking value of axis parameter.
   auto commonType = getOperand(0).getType().cast<RankedTensorType>();
@@ -1669,34 +1598,28 @@ bool ONNXConcatOp::inferShapes() {
     auto builder = mlir::Builder(getContext());
     axisAttr(builder.getI64IntegerAttr(axisIndex));
   }
-  if (axisIndex >= commonRank) {
-    emitError("Concat axis value out of bound");
-    return false;
-  }
+  if (axisIndex >= commonRank)
+    return emitError("Concat axis value out of bound");
   // Initial cummlative size is that of the first operand.
   int cummulativeAxisSize = commonShape[axisIndex];
 
-  // Compute the cummlative size with all of the other ones, and make sure that
-  // the other sizes are all alike.
+  // Compute the cumulative size with all of the other ones, and make sure
+  // that the other sizes are all alike.
   for (int i = 1; i < inputNum; ++i) {
     auto currShape =
         getOperand(i).getType().cast<RankedTensorType>().getShape();
-    if (currShape.size() != commonRank) {
-      emitError("Concat input must all have the same rank");
-      return false;
-    }
+    if (currShape.size() != commonRank)
+      return emitError("Concat input must all have the same rank");
     for (int j = 0; j < commonRank; ++j) {
       if (j == axisIndex) {
         // Check that the value is positive.
-        if (currShape[j] <= 0) {
-          emitError("Concat axis being concatenated is expected to be known at "
-                    "compile time for now");
-          return false;
-        }
+        if (currShape[j] <= 0)
+          return emitError("Concat axis being concatenated is "
+                           "expected to be known at compile time for now");
       } else if (currShape[j] != commonShape[j]) {
-        emitError("Concat input dimensions must be all identical, except for "
-                  "dimension on the axis of the concatenation");
-        return false;
+        return emitError(
+            "Concat input dimensions must be all identical, "
+            "except for dimension on the axis of the concatenation");
      }
     }
    cummulativeAxisSize += currShape[axisIndex];
@@ -1709,32 +1632,30 @@ bool ONNXConcatOp::inferShapes() {
         j == axisIndex ? cummulativeAxisSize : commonShape[j]);
   getResult().setType(
       RankedTensorType::get(outputDims, commonType.getElementType()));
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
 // RNN
 
-bool ONNXRNNOp::inferShapes() { return RNNShapeInference<>(this); }
+LogicalResult ONNXRNNOp::inferShapes() { return RNNShapeInference<>(this); }
 
 //===----------------------------------------------------------------------===//
 // LSTM
 
-bool ONNXLSTMOp::inferShapes() { return RNNShapeInference<>(this); }
+LogicalResult ONNXLSTMOp::inferShapes() { return RNNShapeInference<>(this); }
 
 //===----------------------------------------------------------------------===//
 // GRU
 
-bool ONNXGRUOp::inferShapes() { return RNNShapeInference<>(this); }
+LogicalResult ONNXGRUOp::inferShapes() { return RNNShapeInference<>(this); }
 
 //===----------------------------------------------------------------------===//
 // Split
 
-bool ONNXSplitOp::inferShapes() {
-  if (!getOperand().getType().cast<RankedTensorType>()) {
-    emitError("Input tensor not ranked");
-    return false;
-  }
+LogicalResult ONNXSplitOp::inferShapes() {
+  if (!getOperand().getType().cast<RankedTensorType>())
+    return emitError("Input tensor not ranked");
 
   int numOfResults = getNumResults();
   auto inputType = getOperand().getType().cast<RankedTensorType>();
@@ -1743,10 +1664,8 @@ bool ONNXSplitOp::inferShapes() {
 
   // Checking value of axis parameter.
   auto axisIndex = axis().getSExtValue();
-  if (axisIndex < -inputRank || axisIndex >= inputRank) {
-    emitError("Split axis value out of bound");
-    return false;
-  }
+  if (axisIndex < -inputRank || axisIndex >= inputRank)
+    return emitError("Split axis value out of bound");
   // Negative axis means values are counted from the opposite side.
   if (axisIndex < 0) {
     axisIndex = inputRank + axisIndex;
@@ -1758,23 +1677,18 @@ bool ONNXSplitOp::inferShapes() {
   auto splitAttribute = split();
   SmallVector<int64_t, 4> splitLengths;
   if (splitAttribute.hasValue()) {
-    if (ArrayAttrSize(splitAttribute) != numOfResults) {
-      emitError("Split size not equal to the number of results");
-    }
+    if (ArrayAttrSize(splitAttribute) != numOfResults)
+      return emitError("Split size not equal to the number of results");
     for (int i = 0; i < numOfResults; ++i)
       splitLengths.emplace_back(ArrayAttrIntVal(splitAttribute, i));
   } else {
-    if (inputShape[axisIndex] <= 0) {
-      emitError("The dimension at the split axis is expected to be known at "
-                "compile time");
-      return false;
-    }
-    if (inputShape[axisIndex] % numOfResults != 0) {
-      emitError("The dimension at the split axis is expected to be divisible "
-                "by the number of results");
-      return false;
-    }
+    if (inputShape[axisIndex] <= 0)
+      return emitError("The dimension at the split axis is "
+                       "expected to be known at compile time");
+    if (inputShape[axisIndex] % numOfResults != 0)
+      return emitError("The dimension at the split axis is "
+                       "expected to be divisible by the number of results");
     // If split parameter is not specified, the dimension is split to
    // equal-sized parts.
    for (int i = 0; i < numOfResults; ++i)
@@ -1797,7 +1711,7 @@ bool ONNXSplitOp::inferShapes() {
     getResults()[i].setType(
         RankedTensorType::get(resultShape, inputType.getElementType()));
   }
-  return true;
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/src/Interface/ShapeInferenceInterface.td b/src/Interface/ShapeInferenceInterface.td
index 0d49090..5a493e7 100644
--- a/src/Interface/ShapeInferenceInterface.td
+++ b/src/Interface/ShapeInferenceInterface.td
@@ -25,7 +25,7 @@ def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
   let methods = [
       InterfaceMethod<"Infer and set the output shape for the current operation.",
-                      "bool", "inferShapes">
+                      "LogicalResult", "inferShapes">
   ];
 }
 
diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp
index ad1def7..b156afd 100644
--- a/src/Transform/ONNX/ShapeInferencePass.cpp
+++ b/src/Transform/ONNX/ShapeInferencePass.cpp
@@ -36,7 +36,7 @@ public:
     f.walk([&](mlir::Operation *op) {
       if (returnsDynamicShape(op)) {
        if (auto shape_op = dyn_cast<ShapeInference>(op)) {
-          if (!shape_op.inferShapes()) {
+          if (failed(shape_op.inferShapes())) {
            op->emitError("unable to infer shape of operation without shape "
                           "inference method");
             return signalPassFailure();
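
For reference, the convention this patch applies to every inferShapes method is
sketched below. This is an illustrative example only, not part of the diff:
ONNXFooOp is a hypothetical op name standing in for the unary ops above. The
pattern relies on MLIR's emitError() returning an InFlightDiagnostic, which
converts implicitly to a failed LogicalResult, so reporting a diagnostic and
exiting early collapse into a single return statement.

// Sketch of the post-patch convention (ONNXFooOp is hypothetical).
LogicalResult ONNXFooOp::inferShapes() {
  // Failure path: emitError() converts to failure(), so the diagnostic and
  // the early exit are one statement instead of emitError + return false.
  if (!getOperand().getType().isa<RankedTensorType>())
    return emitError("Input tensor not ranked");

  // Success path: propagate the operand type and report success() instead of
  // the old `return true`.
  getResult().setType(getOperand().getType());
  return success();
}

Callers then test the result with failed(...), as ShapeInferencePass.cpp does
in the final hunk above.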