Handle error shape inference (#47)

* add return to inferShape

* ran clang-format

* minor changes according to review

* fix format
This commit is contained in:
chentong319 2020-03-30 11:22:55 -04:00 committed by GitHub
parent 867406191f
commit 55cbe316fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 183 additions and 91 deletions

View File

@ -151,9 +151,8 @@ static void processConvStrideParam(T *op, Optional<ArrayAttr> kernelShape) {
// Support function that computes default values for pads.
//
template <class T>
static void processConvPadParam(T *op,
ArrayRef<int64_t> inputShape, Optional<ArrayAttr> kernelShape,
Optional<ArrayAttr> stridesOpt,
static void processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
Optional<ArrayAttr> kernelShape, Optional<ArrayAttr> stridesOpt,
Optional<ArrayAttr> dilationsOpt = llvm::None) {
auto builder = mlir::Builder(op->getContext());
@ -341,219 +340,271 @@ ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
// Exp
/// Infer the output shape of the ONNXExpOp. This method is required by the
/// shape inference interface.
void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }
bool ONNXExpOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Tanh
/// Infer the output shape of the ONNXTanhOp. This method is required by the
/// shape inference interface.
void ONNXTanhOp::inferShapes() { getResult().setType(getOperand().getType()); }
bool ONNXTanhOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Sinh
/// Infer the output shape of the ONNXSinhOp. This method is required by the
/// shape inference interface.
void ONNXSinhOp::inferShapes() { getResult().setType(getOperand().getType()); }
bool ONNXSinhOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Cosh
/// Infer the output shape of the ONNXCoshOp. This method is required by the
/// shape inference interface.
void ONNXCoshOp::inferShapes() { getResult().setType(getOperand().getType()); }
bool ONNXCoshOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Cos
/// Infer the output shape of the ONNXCosOp. This method is required by the
/// shape inference interface.
void ONNXCosOp::inferShapes() { getResult().setType(getOperand().getType()); }
bool ONNXCosOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Log
/// Infer the output shape of the ONNXLogOp. This method is required by the
/// shape inference interface.
void ONNXLogOp::inferShapes() { getResult().setType(getOperand().getType()); }
bool ONNXLogOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// HardSigmoid
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
/// the shape inference interface.
void ONNXHardSigmoidOp::inferShapes() {
bool ONNXHardSigmoidOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Sigmoid
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
/// shape inference interface.
void ONNXSigmoidOp::inferShapes() {
bool ONNXSigmoidOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Elu
/// Infer the output shape of the ONNXEluOp. This method is required by the
/// shape inference interface.
void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }
bool ONNXEluOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Relu
/// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface.
void ONNXReluOp::inferShapes() { getResult().setType(getOperand().getType()); }
bool ONNXReluOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// LeakyRelu
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
/// the shape inference interface.
void ONNXLeakyReluOp::inferShapes() {
bool ONNXLeakyReluOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Selu
/// Infer the output shape of the ONNXSeluOp. This method is required by
/// the shape inference interface.
void ONNXSeluOp::inferShapes() { getResult().setType(getOperand().getType()); }
bool ONNXSeluOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Reciprocal
/// Infer the output shape of the ONNXReciprocalOp. This method is required by
/// the shape inference interface.
void ONNXReciprocalOp::inferShapes() {
bool ONNXReciprocalOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Softmax
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
/// the shape inference interface.
void ONNXSoftmaxOp::inferShapes() {
bool ONNXSoftmaxOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Softplus
/// Infer the output shape of the ONNXSoftplusOp. This method is required by
/// the shape inference interface.
void ONNXSoftplusOp::inferShapes() {
bool ONNXSoftplusOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Softsign
/// Infer the output shape of the ONNXSoftsignOp. This method is required by
/// the shape inference interface.
void ONNXSoftsignOp::inferShapes() {
bool ONNXSoftsignOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Sqrt
/// Infer the output shape of the ONNXSqrtOp. This method is required by
/// the shape inference interface.
void ONNXSqrtOp::inferShapes() { getResult().setType(getOperand().getType()); }
bool ONNXSqrtOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Sign
/// Infer the output shape of the ONNXSignOp. This method is required by
/// the shape inference interface.
void ONNXSignOp::inferShapes() { getResult().setType(getOperand().getType()); }
bool ONNXSignOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Abs
/// Infer the output shape of the ONNXAbsOp. This method is required by the
/// shape inference interface.
void ONNXAbsOp::inferShapes() { getResult().setType(getOperand().getType()); }
bool ONNXAbsOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// Add
/// Infer the output shape of the ONNXAddOp. This method is required by the
/// shape inference interface.
void ONNXAddOp::inferShapes() {
bool ONNXAddOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
!getOperand(1).getType().isa<RankedTensorType>()) {
emitError("ONNXAddOp inferShapes failed");
return false;
}
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
return true;
}
//===----------------------------------------------------------------------===//
// Mul
/// Infer the output shape of the ONNXMulOp. This method is required by the
/// shape inference interface.
void ONNXMulOp::inferShapes() {
bool ONNXMulOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
return false;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
return true;
}
//===----------------------------------------------------------------------===//
// Div
/// Infer the output shape of the ONNXDivOp. This method is required by the
/// shape inference interface.
void ONNXDivOp::inferShapes() {
bool ONNXDivOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
return false;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
return true;
}
//===----------------------------------------------------------------------===//
// Sub
/// Infer the output shape of the ONNXSubOp. This method is required by the
/// shape inference interface.
void ONNXSubOp::inferShapes() {
bool ONNXSubOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
return false;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
return true;
}
//===----------------------------------------------------------------------===//
// And
/// Infer the output shape of the ONNXAndOp. This method is required by the
/// shape inference interface.
void ONNXAndOp::inferShapes() {
bool ONNXAndOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
return false;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
return true;
}
//===----------------------------------------------------------------------===//
// Or
/// Infer the output shape of the ONNXOrOp. This method is required by the
/// shape inference interface.
void ONNXOrOp::inferShapes() {
bool ONNXOrOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
return false;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
return true;
}
//===----------------------------------------------------------------------===//
// Xor
/// Infer the output shape of the ONNXXorOp. This method is required by the
/// shape inference interface.
void ONNXXorOp::inferShapes() {
bool ONNXXorOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
return;
return false;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
return true;
}
//===----------------------------------------------------------------------===//
@ -562,10 +613,10 @@ void ONNXXorOp::inferShapes() {
// Sum
/// Infer the output shape of the ONNXSumOp. This method is required by the
/// shape inference interface.
void ONNXSumOp::inferShapes() {
bool ONNXSumOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i).getType().cast<RankedTensorType>())
return;
return false;
}
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) {
@ -573,16 +624,17 @@ void ONNXSumOp::inferShapes() {
resultTy = getBroadcastedType(resultTy, nextTy);
}
getResult().setType(resultTy);
return true;
}
//===----------------------------------------------------------------------===//
// Max
/// Infer the output shape of the ONNXMaxOp. This method is required by the
/// shape inference interface.
void ONNXMaxOp::inferShapes() {
bool ONNXMaxOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i).getType().cast<RankedTensorType>())
return;
return false;
}
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) {
@ -590,16 +642,17 @@ void ONNXMaxOp::inferShapes() {
resultTy = getBroadcastedType(resultTy, nextTy);
}
getResult().setType(resultTy);
return true;
}
//===----------------------------------------------------------------------===//
// Min
/// Infer the output shape of the ONNXMinOp. This method is required by the
/// shape inference interface.
void ONNXMinOp::inferShapes() {
bool ONNXMinOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i).getType().cast<RankedTensorType>())
return;
return false;
}
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) {
@ -607,25 +660,27 @@ void ONNXMinOp::inferShapes() {
resultTy = getBroadcastedType(resultTy, nextTy);
}
getResult().setType(resultTy);
return true;
}
//===----------------------------------------------------------------------===//
// Identity
/// Infer the output shape of the ONNXIdentityOp. This method is required by the
/// shape inference interface.
void ONNXIdentityOp::inferShapes() {
bool ONNXIdentityOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===//
// MatMul
void ONNXMatMulOp::inferShapes() {
bool ONNXMatMulOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!A().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>())
return;
return false;
auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = B().getType().cast<RankedTensorType>();
@ -752,19 +807,20 @@ void ONNXMatMulOp::inferShapes() {
}
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
return true;
}
//===----------------------------------------------------------------------===//
// Gemm
void ONNXGemmOp::inferShapes() {
bool ONNXGemmOp::inferShapes() {
bool hasBias = !C().getType().isa<NoneType>();
// Cannot infer shape if no shape exists.
if (!A().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>() ||
(hasBias && !C().getType().isa<RankedTensorType>()))
return;
return false;
auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = B().getType().cast<RankedTensorType>();
@ -796,17 +852,18 @@ void ONNXGemmOp::inferShapes() {
dims.emplace_back(M);
dims.emplace_back(N);
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
return true;
}
/// BatchNormalizationTestMode
void ONNXBatchNormalizationTestModeOp::inferShapes() {
bool ONNXBatchNormalizationTestModeOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>() ||
!scale().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>() ||
!mean().getType().isa<RankedTensorType>() ||
!var().getType().isa<RankedTensorType>())
return;
return false;
auto inputTensorTy = X().getType().cast<RankedTensorType>();
auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
@ -845,6 +902,7 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {
// The output tensor of the same shape as the input.
getResult().setType(X().getType());
return true;
}
// TODO:
@ -855,7 +913,7 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {
// Reshape
void ONNXReshapeOp::inferShapes() {
bool ONNXReshapeOp::inferShapes() {
// Cannot infer shape if no shape tensor is specified.
if (!shape().getType().isa<RankedTensorType>())
emitError("Shape tensor not ranked");
@ -875,7 +933,7 @@ void ONNXReshapeOp::inferShapes() {
// Compute total number of elements.
int64_t totalInputSize = 1;
for(auto inputDim : inputTensorTy.getShape())
for (auto inputDim : inputTensorTy.getShape())
totalInputSize *= inputDim;
// Check if second argument of ReshapeOp is a constant.
@ -891,7 +949,7 @@ void ONNXReshapeOp::inferShapes() {
// Get dims from valueAttribute.
auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
for (int i=0; i<outputRank; ++i)
for (int i = 0; i < outputRank; ++i)
dims[i] = (*valueIt++).cast<IntegerAttr>().getInt();
if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
@ -900,7 +958,7 @@ void ONNXReshapeOp::inferShapes() {
int64_t numberOfDynamicInputs = 0;
int64_t totalKnownDimsSize = 1;
int64_t dynamicValueIndex = -1;
for (int i=0; i<outputRank; ++i) {
for (int i = 0; i < outputRank; ++i) {
// Set output dimension.
if (dims[i] == 0)
dims[i] = inputTensorTy.getShape()[i];
@ -924,16 +982,17 @@ void ONNXReshapeOp::inferShapes() {
getResult().setType(
RankedTensorType::get(dims, inputTensorTy.getElementType()));
return true;
}
//===----------------------------------------------------------------------===//
// Transpose
void ONNXTransposeOp::inferShapes() {
bool ONNXTransposeOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!data().getType().isa<RankedTensorType>())
return;
return false;
// Naive transposition which handles the default case of
// reversing the shape of the tensor (similar to numpy.transpose).
@ -951,62 +1010,67 @@ void ONNXTransposeOp::inferShapes() {
}
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
return true;
}
//===----------------------------------------------------------------------===//
// ReduceMax
void ONNXReduceMaxOp::inferShapes() {
bool ONNXReduceMaxOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked");
return;
return false;
}
auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
return true;
}
//===----------------------------------------------------------------------===//
// ReduceMin
void ONNXReduceMinOp::inferShapes() {
bool ONNXReduceMinOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked");
return;
return false;
}
auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
return true;
}
//===----------------------------------------------------------------------===//
// ReduceProd
void ONNXReduceProdOp::inferShapes() {
bool ONNXReduceProdOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked");
return;
return false;
}
auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
return true;
}
//===----------------------------------------------------------------------===//
// ReduceSum
void ONNXReduceSumOp::inferShapes() {
bool ONNXReduceSumOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked");
return;
return false;
}
auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
return true;
}
//===----------------------------------------------------------------------===//
@ -1022,7 +1086,7 @@ void ONNXReduceSumOp::inferShapes() {
// - kernelShape: inferred from weight matrix if not defined by user;
// - pads: set to proper value, 0 if not defined by user.
void ONNXConvOp::inferShapes() {
bool ONNXConvOp::inferShapes() {
// Generic shape for data input X, weight tensor W, and optional bias B
// X: (N x C x D1 x D2 ... x Dn)
// W: (M x C/group x k1 x k2 x ... x kn)
@ -1034,7 +1098,7 @@ void ONNXConvOp::inferShapes() {
if (!X().getType().isa<RankedTensorType>() ||
!W().getType().isa<RankedTensorType>() ||
(hasBias && !B().getType().isa<RankedTensorType>()))
return;
return false;
auto xTy = X().getType().cast<RankedTensorType>();
auto xShape = xTy.getShape();
@ -1043,12 +1107,16 @@ void ONNXConvOp::inferShapes() {
auto builder = mlir::Builder(this->getContext());
// Lowest supported convolution is a one dimensional convolution.
if (xShape.size() < 3)
if (xShape.size() < 3) {
emitError("Data input shape must be at least (NxCxD1)");
return false;
}
// Check that shape of weight and data have same length.
if (xShape.size() != weightShape.size())
if (xShape.size() != weightShape.size()) {
emitError("Weight size not compatible with data size");
return false;
}
// Group is a required attribute and should have default value of 1.
int64_t group = ONNXConvOp::group().getSExtValue();
@ -1059,17 +1127,23 @@ void ONNXConvOp::inferShapes() {
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (xShape[1] != -1 && weightShape[1] != -1 &&
xShape[1] != (weightShape[1] * group))
xShape[1] != (weightShape[1] * group)) {
emitError("Channel dimension mismatch");
return false;
}
// Check the size of bias.
if (hasBias) {
auto bTx = B().getType().cast<RankedTensorType>();
auto bShape = bTx.getShape();
if (bShape.size() != 1)
if (bShape.size() != 1) {
emitError("bias should be one dimensional");
if (bShape[0] != weightShape[0])
return false;
}
if (bShape[0] != weightShape[0]) {
emitError("bias should have same dimensions as weight's first dimension");
return false;
}
}
// Note: the value of the group attribut only impacts the way the
@ -1083,12 +1157,16 @@ void ONNXConvOp::inferShapes() {
// argument.
auto kernelShape = kernel_shape();
if (kernelShape.hasValue()) {
if (ArrayAttrSize(kernelShape) != spatialRank)
if (ArrayAttrSize(kernelShape) != spatialRank) {
emitError("kernel_shape length incompatible with spatial dimensions");
return false;
}
// Have the right number of values, check them.
for (int i = 0; i < spatialRank; ++i)
if (ArrayAttrIntVal(kernelShape, i) < 1)
if (ArrayAttrIntVal(kernelShape, i) < 1) {
emitError("bad kernel_shape value");
return false;
}
} else {
// Deduce shape from weight input.
SmallVector<int64_t, 2> defaultVals;
@ -1119,6 +1197,7 @@ void ONNXConvOp::inferShapes() {
&outputDims, xShape, kernelShape, padsOpt, stridesOpt, dilationsOpt);
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
return true;
}
//===----------------------------------------------------------------------===//
@ -1129,10 +1208,10 @@ void ONNXConvOp::inferShapes() {
// - strides: set to 1 if not defined by user;
// - pads: set to proper value, 0 if not defined by user.
void ONNXAveragePoolOp::inferShapes() {
bool ONNXAveragePoolOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>())
return;
return false;
// Get shape of input.
auto xTy = X().getType().cast<RankedTensorType>();
@ -1163,6 +1242,7 @@ void ONNXAveragePoolOp::inferShapes() {
llvm::None, ceilMode);
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
return true;
}
//===----------------------------------------------------------------------===//
@ -1173,10 +1253,10 @@ void ONNXAveragePoolOp::inferShapes() {
// - dilations, strides: set to 1 if not defined by user;
// - pads: set to proper value, 0 if not defined by user.
void ONNXMaxPoolSingleOutOp::inferShapes() {
bool ONNXMaxPoolSingleOutOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>())
return;
return false;
// Get shape of input.
auto xTy = X().getType().cast<RankedTensorType>();
@ -1211,6 +1291,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
dilationsOpt, ceilMode);
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
return true;
}
//===----------------------------------------------------------------------===//
@ -1246,24 +1327,26 @@ static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
// PadConstantPad
void ONNXPadConstantPadOp::inferShapes() {
bool ONNXPadConstantPadOp::inferShapes() {
auto outputType = padShapeInferenceHelper(data(), pads());
if (outputType) {
getResult().setType(outputType);
return true;
}
return;
return false;
}
//===----------------------------------------------------------------------===//
// PadConstantValuePad
void ONNXPadConstantValuePadOp::inferShapes() {
bool ONNXPadConstantValuePadOp::inferShapes() {
auto outputType = padShapeInferenceHelper(data(), pads());
if (outputType) {
getResult().setType(outputType);
return true;
}
return;
return false;
}
void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
@ -1280,9 +1363,9 @@ void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
// Unsqueeze
void ONNXUnsqueezeOp::inferShapes() {
bool ONNXUnsqueezeOp::inferShapes() {
if (!data().getType().isa<RankedTensorType>())
return;
return false;
auto operandTy = data().getType().cast<RankedTensorType>();
int inRank = operandTy.getRank();
@ -1299,11 +1382,14 @@ void ONNXUnsqueezeOp::inferShapes() {
assert(axis >= -outRank && axis <= outRank - 1);
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
axes.emplace_back(axis);
else
else {
emitError("Duplicated axes");
return false;
}
}
} else {
emitError("Axes attribute is required");
return false;
}
SmallVector<int64_t, 4> dims;
@ -1315,12 +1401,13 @@ void ONNXUnsqueezeOp::inferShapes() {
}
}
getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
return true;
}
//===----------------------------------------------------------------------===//
// Constant
void ONNXConstantOp::inferShapes() {
bool ONNXConstantOp::inferShapes() {
if ((sparse_value().hasValue() && value().hasValue()) ||
(!sparse_value().hasValue() && !value().hasValue()))
emitError("Require exactly one of the two attributes, either value or "
@ -1332,6 +1419,7 @@ void ONNXConstantOp::inferShapes() {
else
valAttr = valueAttr().cast<DenseElementsAttr>();
getResult().setType(valAttr.getType());
return true;
}
//===----------------------------------------------------------------------===//

View File

@ -25,7 +25,7 @@ def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
let methods = [
InterfaceMethod<"Infer and set the output shape for the current operation.",
"void", "inferShapes">
"bool", "inferShapes">
];
}

View File

@ -35,7 +35,11 @@ public:
f.walk([&](mlir::Operation *op) {
if (returnsDynamicShape(op)) {
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
shape_op.inferShapes();
if (!shape_op.inferShapes()) {
op->emitError("unable to infer shape of operation without shape "
"inference interface");
return signalPassFailure();
}
} else {
op->emitError("unable to infer shape of operation without shape "
"inference interface");