Handle error shape inference (#47)
* add return to inferShape * ran clang-format * minor changes according to review * fix format
This commit is contained in:
parent
867406191f
commit
55cbe316fd
|
@ -151,9 +151,8 @@ static void processConvStrideParam(T *op, Optional<ArrayAttr> kernelShape) {
|
|||
// Support function that computes default values for pads.
|
||||
//
|
||||
template <class T>
|
||||
static void processConvPadParam(T *op,
|
||||
ArrayRef<int64_t> inputShape, Optional<ArrayAttr> kernelShape,
|
||||
Optional<ArrayAttr> stridesOpt,
|
||||
static void processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
|
||||
Optional<ArrayAttr> kernelShape, Optional<ArrayAttr> stridesOpt,
|
||||
Optional<ArrayAttr> dilationsOpt = llvm::None) {
|
||||
auto builder = mlir::Builder(op->getContext());
|
||||
|
||||
|
@ -256,7 +255,7 @@ static void processConvTypeParams(T *op, Value inputOperand) {
|
|||
processConvDilationParam<T>(op, kernelShape);
|
||||
auto dilationsOpt = op->dilations();
|
||||
|
||||
// Strides.
|
||||
// Strides.
|
||||
processConvStrideParam<T>(op, kernelShape);
|
||||
auto stridesOpt = op->strides();
|
||||
|
||||
|
@ -341,219 +340,271 @@ ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
|
|||
// Exp
|
||||
/// Infer the output shape of the ONNXExpOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
bool ONNXExpOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tanh
|
||||
/// Infer the output shape of the ONNXTanhOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXTanhOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
bool ONNXTanhOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sinh
|
||||
/// Infer the output shape of the ONNXSinhOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXSinhOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
bool ONNXSinhOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Cosh
|
||||
/// Infer the output shape of the ONNXCoshOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXCoshOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
bool ONNXCoshOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Cos
|
||||
/// Infer the output shape of the ONNXCosOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXCosOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
bool ONNXCosOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Log
|
||||
/// Infer the output shape of the ONNXLogOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXLogOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
bool ONNXLogOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// HardSigmoid
|
||||
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
void ONNXHardSigmoidOp::inferShapes() {
|
||||
bool ONNXHardSigmoidOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sigmoid
|
||||
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXSigmoidOp::inferShapes() {
|
||||
bool ONNXSigmoidOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Elu
|
||||
/// Infer the output shape of the ONNXEluOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
bool ONNXEluOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Relu
|
||||
/// Infer the output shape of the ONNXReluOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXReluOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
bool ONNXReluOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LeakyRelu
|
||||
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
void ONNXLeakyReluOp::inferShapes() {
|
||||
bool ONNXLeakyReluOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Selu
|
||||
/// Infer the output shape of the ONNXSeluOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
void ONNXSeluOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
bool ONNXSeluOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Reciprocal
|
||||
/// Infer the output shape of the ONNXReciprocalOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
void ONNXReciprocalOp::inferShapes() {
|
||||
bool ONNXReciprocalOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Softmax
|
||||
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
void ONNXSoftmaxOp::inferShapes() {
|
||||
bool ONNXSoftmaxOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Softplus
|
||||
/// Infer the output shape of the ONNXSoftplusOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
void ONNXSoftplusOp::inferShapes() {
|
||||
bool ONNXSoftplusOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Softsign
|
||||
/// Infer the output shape of the ONNXSoftsignOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
void ONNXSoftsignOp::inferShapes() {
|
||||
bool ONNXSoftsignOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sqrt
|
||||
/// Infer the output shape of the ONNXSqrtOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
void ONNXSqrtOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
bool ONNXSqrtOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sign
|
||||
/// Infer the output shape of the ONNXSignOp. This method is required by
|
||||
/// the shape inference interface.
|
||||
void ONNXSignOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
bool ONNXSignOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Abs
|
||||
/// Infer the output shape of the ONNXAbsOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXAbsOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
||||
bool ONNXAbsOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Add
|
||||
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXAddOp::inferShapes() {
|
||||
bool ONNXAddOp::inferShapes() {
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
return;
|
||||
!getOperand(1).getType().isa<RankedTensorType>()) {
|
||||
emitError("ONNXAddOp inferShapes failed");
|
||||
return false;
|
||||
}
|
||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Mul
|
||||
/// Infer the output shape of the ONNXMulOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXMulOp::inferShapes() {
|
||||
bool ONNXMulOp::inferShapes() {
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
return;
|
||||
return false;
|
||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Div
|
||||
/// Infer the output shape of the ONNXDivOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXDivOp::inferShapes() {
|
||||
bool ONNXDivOp::inferShapes() {
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
return;
|
||||
return false;
|
||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Sub
|
||||
/// Infer the output shape of the ONNXSubOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXSubOp::inferShapes() {
|
||||
bool ONNXSubOp::inferShapes() {
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
return;
|
||||
return false;
|
||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// And
|
||||
/// Infer the output shape of the ONNXAndOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXAndOp::inferShapes() {
|
||||
bool ONNXAndOp::inferShapes() {
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
return;
|
||||
return false;
|
||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Or
|
||||
/// Infer the output shape of the ONNXOrOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXOrOp::inferShapes() {
|
||||
bool ONNXOrOp::inferShapes() {
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
return;
|
||||
return false;
|
||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Xor
|
||||
/// Infer the output shape of the ONNXXorOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXXorOp::inferShapes() {
|
||||
bool ONNXXorOp::inferShapes() {
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
return;
|
||||
return false;
|
||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -562,10 +613,10 @@ void ONNXXorOp::inferShapes() {
|
|||
// Sum
|
||||
/// Infer the output shape of the ONNXSumOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXSumOp::inferShapes() {
|
||||
bool ONNXSumOp::inferShapes() {
|
||||
for (int i = 0; i < getNumOperands(); ++i) {
|
||||
if (!getOperand(i).getType().cast<RankedTensorType>())
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
for (int i = 1; i < getNumOperands(); ++i) {
|
||||
|
@ -573,16 +624,17 @@ void ONNXSumOp::inferShapes() {
|
|||
resultTy = getBroadcastedType(resultTy, nextTy);
|
||||
}
|
||||
getResult().setType(resultTy);
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Max
|
||||
/// Infer the output shape of the ONNXMaxOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXMaxOp::inferShapes() {
|
||||
bool ONNXMaxOp::inferShapes() {
|
||||
for (int i = 0; i < getNumOperands(); ++i) {
|
||||
if (!getOperand(i).getType().cast<RankedTensorType>())
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
for (int i = 1; i < getNumOperands(); ++i) {
|
||||
|
@ -590,16 +642,17 @@ void ONNXMaxOp::inferShapes() {
|
|||
resultTy = getBroadcastedType(resultTy, nextTy);
|
||||
}
|
||||
getResult().setType(resultTy);
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Min
|
||||
/// Infer the output shape of the ONNXMinOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXMinOp::inferShapes() {
|
||||
bool ONNXMinOp::inferShapes() {
|
||||
for (int i = 0; i < getNumOperands(); ++i) {
|
||||
if (!getOperand(i).getType().cast<RankedTensorType>())
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
for (int i = 1; i < getNumOperands(); ++i) {
|
||||
|
@ -607,25 +660,27 @@ void ONNXMinOp::inferShapes() {
|
|||
resultTy = getBroadcastedType(resultTy, nextTy);
|
||||
}
|
||||
getResult().setType(resultTy);
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Identity
|
||||
/// Infer the output shape of the ONNXIdentityOp. This method is required by the
|
||||
/// shape inference interface.
|
||||
void ONNXIdentityOp::inferShapes() {
|
||||
bool ONNXIdentityOp::inferShapes() {
|
||||
getResult().setType(getOperand().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// MatMul
|
||||
|
||||
void ONNXMatMulOp::inferShapes() {
|
||||
bool ONNXMatMulOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!A().getType().isa<RankedTensorType>() ||
|
||||
!B().getType().isa<RankedTensorType>())
|
||||
return;
|
||||
return false;
|
||||
|
||||
auto lhsTy = A().getType().cast<RankedTensorType>();
|
||||
auto rhsTy = B().getType().cast<RankedTensorType>();
|
||||
|
@ -752,19 +807,20 @@ void ONNXMatMulOp::inferShapes() {
|
|||
}
|
||||
|
||||
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Gemm
|
||||
|
||||
void ONNXGemmOp::inferShapes() {
|
||||
bool ONNXGemmOp::inferShapes() {
|
||||
bool hasBias = !C().getType().isa<NoneType>();
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!A().getType().isa<RankedTensorType>() ||
|
||||
!B().getType().isa<RankedTensorType>() ||
|
||||
(hasBias && !C().getType().isa<RankedTensorType>()))
|
||||
return;
|
||||
return false;
|
||||
auto lhsTy = A().getType().cast<RankedTensorType>();
|
||||
auto rhsTy = B().getType().cast<RankedTensorType>();
|
||||
|
||||
|
@ -796,17 +852,18 @@ void ONNXGemmOp::inferShapes() {
|
|||
dims.emplace_back(M);
|
||||
dims.emplace_back(N);
|
||||
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||
return true;
|
||||
}
|
||||
|
||||
/// BatchNormalizationTestMode
|
||||
void ONNXBatchNormalizationTestModeOp::inferShapes() {
|
||||
bool ONNXBatchNormalizationTestModeOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!X().getType().isa<RankedTensorType>() ||
|
||||
!scale().getType().isa<RankedTensorType>() ||
|
||||
!B().getType().isa<RankedTensorType>() ||
|
||||
!mean().getType().isa<RankedTensorType>() ||
|
||||
!var().getType().isa<RankedTensorType>())
|
||||
return;
|
||||
return false;
|
||||
|
||||
auto inputTensorTy = X().getType().cast<RankedTensorType>();
|
||||
auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
|
||||
|
@ -845,6 +902,7 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {
|
|||
|
||||
// The output tensor of the same shape as the input.
|
||||
getResult().setType(X().getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
// TODO:
|
||||
|
@ -855,7 +913,7 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {
|
|||
|
||||
// Reshape
|
||||
|
||||
void ONNXReshapeOp::inferShapes() {
|
||||
bool ONNXReshapeOp::inferShapes() {
|
||||
// Cannot infer shape if no shape tensor is specified.
|
||||
if (!shape().getType().isa<RankedTensorType>())
|
||||
emitError("Shape tensor not ranked");
|
||||
|
@ -875,7 +933,7 @@ void ONNXReshapeOp::inferShapes() {
|
|||
|
||||
// Compute total number of elements.
|
||||
int64_t totalInputSize = 1;
|
||||
for(auto inputDim : inputTensorTy.getShape())
|
||||
for (auto inputDim : inputTensorTy.getShape())
|
||||
totalInputSize *= inputDim;
|
||||
|
||||
// Check if second argument of ReshapeOp is a constant.
|
||||
|
@ -891,7 +949,7 @@ void ONNXReshapeOp::inferShapes() {
|
|||
|
||||
// Get dims from valueAttribute.
|
||||
auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
|
||||
for (int i=0; i<outputRank; ++i)
|
||||
for (int i = 0; i < outputRank; ++i)
|
||||
dims[i] = (*valueIt++).cast<IntegerAttr>().getInt();
|
||||
|
||||
if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
|
||||
|
@ -900,7 +958,7 @@ void ONNXReshapeOp::inferShapes() {
|
|||
int64_t numberOfDynamicInputs = 0;
|
||||
int64_t totalKnownDimsSize = 1;
|
||||
int64_t dynamicValueIndex = -1;
|
||||
for (int i=0; i<outputRank; ++i) {
|
||||
for (int i = 0; i < outputRank; ++i) {
|
||||
// Set output dimension.
|
||||
if (dims[i] == 0)
|
||||
dims[i] = inputTensorTy.getShape()[i];
|
||||
|
@ -924,16 +982,17 @@ void ONNXReshapeOp::inferShapes() {
|
|||
|
||||
getResult().setType(
|
||||
RankedTensorType::get(dims, inputTensorTy.getElementType()));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Transpose
|
||||
|
||||
void ONNXTransposeOp::inferShapes() {
|
||||
bool ONNXTransposeOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!data().getType().isa<RankedTensorType>())
|
||||
return;
|
||||
return false;
|
||||
|
||||
// Naive transposition which handles the default case of
|
||||
// reversing the shape of the tensor (similar to numpy.transpose).
|
||||
|
@ -951,62 +1010,67 @@ void ONNXTransposeOp::inferShapes() {
|
|||
}
|
||||
|
||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// ReduceMax
|
||||
|
||||
void ONNXReduceMaxOp::inferShapes() {
|
||||
bool ONNXReduceMaxOp::inferShapes() {
|
||||
if (!getOperand().getType().isa<RankedTensorType>()) {
|
||||
emitError("Shape tensor not ranked");
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
||||
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// ReduceMin
|
||||
|
||||
void ONNXReduceMinOp::inferShapes() {
|
||||
bool ONNXReduceMinOp::inferShapes() {
|
||||
if (!getOperand().getType().isa<RankedTensorType>()) {
|
||||
emitError("Shape tensor not ranked");
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
||||
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// ReduceProd
|
||||
|
||||
void ONNXReduceProdOp::inferShapes() {
|
||||
bool ONNXReduceProdOp::inferShapes() {
|
||||
if (!getOperand().getType().isa<RankedTensorType>()) {
|
||||
emitError("Shape tensor not ranked");
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
||||
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// ReduceSum
|
||||
|
||||
void ONNXReduceSumOp::inferShapes() {
|
||||
bool ONNXReduceSumOp::inferShapes() {
|
||||
if (!getOperand().getType().isa<RankedTensorType>()) {
|
||||
emitError("Shape tensor not ranked");
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
||||
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1022,7 +1086,7 @@ void ONNXReduceSumOp::inferShapes() {
|
|||
// - kernelShape: inferred from weight matrix if not defined by user;
|
||||
// - pads: set to proper value, 0 if not defined by user.
|
||||
|
||||
void ONNXConvOp::inferShapes() {
|
||||
bool ONNXConvOp::inferShapes() {
|
||||
// Generic shape for data input X, weight tensor W, and optional bias B
|
||||
// X: (N x C x D1 x D2 ... x Dn)
|
||||
// W: (M x C/group x k1 x k2 x ... x kn)
|
||||
|
@ -1034,7 +1098,7 @@ void ONNXConvOp::inferShapes() {
|
|||
if (!X().getType().isa<RankedTensorType>() ||
|
||||
!W().getType().isa<RankedTensorType>() ||
|
||||
(hasBias && !B().getType().isa<RankedTensorType>()))
|
||||
return;
|
||||
return false;
|
||||
|
||||
auto xTy = X().getType().cast<RankedTensorType>();
|
||||
auto xShape = xTy.getShape();
|
||||
|
@ -1043,12 +1107,16 @@ void ONNXConvOp::inferShapes() {
|
|||
auto builder = mlir::Builder(this->getContext());
|
||||
|
||||
// Lowest supported convolution is a one dimensional convolution.
|
||||
if (xShape.size() < 3)
|
||||
if (xShape.size() < 3) {
|
||||
emitError("Data input shape must be at least (NxCxD1)");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that shape of weight and data have same length.
|
||||
if (xShape.size() != weightShape.size())
|
||||
if (xShape.size() != weightShape.size()) {
|
||||
emitError("Weight size not compatible with data size");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Group is a required attribute and should have default value of 1.
|
||||
int64_t group = ONNXConvOp::group().getSExtValue();
|
||||
|
@ -1059,17 +1127,23 @@ void ONNXConvOp::inferShapes() {
|
|||
|
||||
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
|
||||
if (xShape[1] != -1 && weightShape[1] != -1 &&
|
||||
xShape[1] != (weightShape[1] * group))
|
||||
xShape[1] != (weightShape[1] * group)) {
|
||||
emitError("Channel dimension mismatch");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check the size of bias.
|
||||
if (hasBias) {
|
||||
auto bTx = B().getType().cast<RankedTensorType>();
|
||||
auto bShape = bTx.getShape();
|
||||
if (bShape.size() != 1)
|
||||
if (bShape.size() != 1) {
|
||||
emitError("bias should be one dimensional");
|
||||
if (bShape[0] != weightShape[0])
|
||||
return false;
|
||||
}
|
||||
if (bShape[0] != weightShape[0]) {
|
||||
emitError("bias should have same dimensions as weight's first dimension");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Note: the value of the group attribut only impacts the way the
|
||||
|
@ -1083,12 +1157,16 @@ void ONNXConvOp::inferShapes() {
|
|||
// argument.
|
||||
auto kernelShape = kernel_shape();
|
||||
if (kernelShape.hasValue()) {
|
||||
if (ArrayAttrSize(kernelShape) != spatialRank)
|
||||
if (ArrayAttrSize(kernelShape) != spatialRank) {
|
||||
emitError("kernel_shape length incompatible with spatial dimensions");
|
||||
return false;
|
||||
}
|
||||
// Have the right number of values, check them.
|
||||
for (int i = 0; i < spatialRank; ++i)
|
||||
if (ArrayAttrIntVal(kernelShape, i) < 1)
|
||||
if (ArrayAttrIntVal(kernelShape, i) < 1) {
|
||||
emitError("bad kernel_shape value");
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// Deduce shape from weight input.
|
||||
SmallVector<int64_t, 2> defaultVals;
|
||||
|
@ -1119,6 +1197,7 @@ void ONNXConvOp::inferShapes() {
|
|||
&outputDims, xShape, kernelShape, padsOpt, stridesOpt, dilationsOpt);
|
||||
|
||||
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1129,10 +1208,10 @@ void ONNXConvOp::inferShapes() {
|
|||
// - strides: set to 1 if not defined by user;
|
||||
// - pads: set to proper value, 0 if not defined by user.
|
||||
|
||||
void ONNXAveragePoolOp::inferShapes() {
|
||||
bool ONNXAveragePoolOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!X().getType().isa<RankedTensorType>())
|
||||
return;
|
||||
return false;
|
||||
|
||||
// Get shape of input.
|
||||
auto xTy = X().getType().cast<RankedTensorType>();
|
||||
|
@ -1163,6 +1242,7 @@ void ONNXAveragePoolOp::inferShapes() {
|
|||
llvm::None, ceilMode);
|
||||
|
||||
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1173,10 +1253,10 @@ void ONNXAveragePoolOp::inferShapes() {
|
|||
// - dilations, strides: set to 1 if not defined by user;
|
||||
// - pads: set to proper value, 0 if not defined by user.
|
||||
|
||||
void ONNXMaxPoolSingleOutOp::inferShapes() {
|
||||
bool ONNXMaxPoolSingleOutOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!X().getType().isa<RankedTensorType>())
|
||||
return;
|
||||
return false;
|
||||
|
||||
// Get shape of input.
|
||||
auto xTy = X().getType().cast<RankedTensorType>();
|
||||
|
@ -1211,6 +1291,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
|
|||
dilationsOpt, ceilMode);
|
||||
|
||||
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1246,24 +1327,26 @@ static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
|
|||
|
||||
// PadConstantPad
|
||||
|
||||
void ONNXPadConstantPadOp::inferShapes() {
|
||||
bool ONNXPadConstantPadOp::inferShapes() {
|
||||
auto outputType = padShapeInferenceHelper(data(), pads());
|
||||
if (outputType) {
|
||||
getResult().setType(outputType);
|
||||
return true;
|
||||
}
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// PadConstantValuePad
|
||||
|
||||
void ONNXPadConstantValuePadOp::inferShapes() {
|
||||
bool ONNXPadConstantValuePadOp::inferShapes() {
|
||||
auto outputType = padShapeInferenceHelper(data(), pads());
|
||||
if (outputType) {
|
||||
getResult().setType(outputType);
|
||||
return true;
|
||||
}
|
||||
return;
|
||||
return false;
|
||||
}
|
||||
|
||||
void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
|
||||
|
@ -1280,9 +1363,9 @@ void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
|
|||
|
||||
// Unsqueeze
|
||||
|
||||
void ONNXUnsqueezeOp::inferShapes() {
|
||||
bool ONNXUnsqueezeOp::inferShapes() {
|
||||
if (!data().getType().isa<RankedTensorType>())
|
||||
return;
|
||||
return false;
|
||||
|
||||
auto operandTy = data().getType().cast<RankedTensorType>();
|
||||
int inRank = operandTy.getRank();
|
||||
|
@ -1299,11 +1382,14 @@ void ONNXUnsqueezeOp::inferShapes() {
|
|||
assert(axis >= -outRank && axis <= outRank - 1);
|
||||
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
|
||||
axes.emplace_back(axis);
|
||||
else
|
||||
else {
|
||||
emitError("Duplicated axes");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
emitError("Axes attribute is required");
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<int64_t, 4> dims;
|
||||
|
@ -1315,12 +1401,13 @@ void ONNXUnsqueezeOp::inferShapes() {
|
|||
}
|
||||
}
|
||||
getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Constant
|
||||
|
||||
void ONNXConstantOp::inferShapes() {
|
||||
bool ONNXConstantOp::inferShapes() {
|
||||
if ((sparse_value().hasValue() && value().hasValue()) ||
|
||||
(!sparse_value().hasValue() && !value().hasValue()))
|
||||
emitError("Require exactly one of the two attributes, either value or "
|
||||
|
@ -1332,6 +1419,7 @@ void ONNXConstantOp::inferShapes() {
|
|||
else
|
||||
valAttr = valueAttr().cast<DenseElementsAttr>();
|
||||
getResult().setType(valAttr.getType());
|
||||
return true;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -25,7 +25,7 @@ def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
|
|||
|
||||
let methods = [
|
||||
InterfaceMethod<"Infer and set the output shape for the current operation.",
|
||||
"void", "inferShapes">
|
||||
"bool", "inferShapes">
|
||||
];
|
||||
}
|
||||
|
||||
|
|
|
@ -35,7 +35,11 @@ public:
|
|||
f.walk([&](mlir::Operation *op) {
|
||||
if (returnsDynamicShape(op)) {
|
||||
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
||||
shape_op.inferShapes();
|
||||
if (!shape_op.inferShapes()) {
|
||||
op->emitError("unable to infer shape of operation without shape "
|
||||
"inference interface");
|
||||
return signalPassFailure();
|
||||
}
|
||||
} else {
|
||||
op->emitError("unable to infer shape of operation without shape "
|
||||
"inference interface");
|
||||
|
|
Loading…
Reference in New Issue