Handle error shape inference (#47)

* add return to inferShape

* ran clang-format

* minor changes according to review

* fix format
This commit is contained in:
chentong319 2020-03-30 11:22:55 -04:00 committed by GitHub
parent 867406191f
commit 55cbe316fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 183 additions and 91 deletions

View File

@ -151,9 +151,8 @@ static void processConvStrideParam(T *op, Optional<ArrayAttr> kernelShape) {
// Support function that computes default values for pads. // Support function that computes default values for pads.
// //
template <class T> template <class T>
static void processConvPadParam(T *op, static void processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> inputShape, Optional<ArrayAttr> kernelShape, Optional<ArrayAttr> kernelShape, Optional<ArrayAttr> stridesOpt,
Optional<ArrayAttr> stridesOpt,
Optional<ArrayAttr> dilationsOpt = llvm::None) { Optional<ArrayAttr> dilationsOpt = llvm::None) {
auto builder = mlir::Builder(op->getContext()); auto builder = mlir::Builder(op->getContext());
@ -256,7 +255,7 @@ static void processConvTypeParams(T *op, Value inputOperand) {
processConvDilationParam<T>(op, kernelShape); processConvDilationParam<T>(op, kernelShape);
auto dilationsOpt = op->dilations(); auto dilationsOpt = op->dilations();
// Strides. // Strides.
processConvStrideParam<T>(op, kernelShape); processConvStrideParam<T>(op, kernelShape);
auto stridesOpt = op->strides(); auto stridesOpt = op->strides();
@ -341,219 +340,271 @@ ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
// Exp // Exp
/// Infer the output shape of the ONNXExpOp. This method is required by the /// Infer the output shape of the ONNXExpOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); } bool ONNXExpOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Tanh // Tanh
/// Infer the output shape of the ONNXTanhOp. This method is required by the /// Infer the output shape of the ONNXTanhOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXTanhOp::inferShapes() { getResult().setType(getOperand().getType()); } bool ONNXTanhOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sinh // Sinh
/// Infer the output shape of the ONNXSinhOp. This method is required by the /// Infer the output shape of the ONNXSinhOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXSinhOp::inferShapes() { getResult().setType(getOperand().getType()); } bool ONNXSinhOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Cosh // Cosh
/// Infer the output shape of the ONNXCoshOp. This method is required by the /// Infer the output shape of the ONNXCoshOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXCoshOp::inferShapes() { getResult().setType(getOperand().getType()); } bool ONNXCoshOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Cos // Cos
/// Infer the output shape of the ONNXCosOp. This method is required by the /// Infer the output shape of the ONNXCosOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXCosOp::inferShapes() { getResult().setType(getOperand().getType()); } bool ONNXCosOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Log // Log
/// Infer the output shape of the ONNXLogOp. This method is required by the /// Infer the output shape of the ONNXLogOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXLogOp::inferShapes() { getResult().setType(getOperand().getType()); } bool ONNXLogOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// HardSigmoid // HardSigmoid
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by /// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXHardSigmoidOp::inferShapes() { bool ONNXHardSigmoidOp::inferShapes() {
getResult().setType(getOperand().getType()); getResult().setType(getOperand().getType());
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sigmoid // Sigmoid
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the /// Infer the output shape of the ONNXSigmoidOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXSigmoidOp::inferShapes() { bool ONNXSigmoidOp::inferShapes() {
getResult().setType(getOperand().getType()); getResult().setType(getOperand().getType());
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Elu // Elu
/// Infer the output shape of the ONNXEluOp. This method is required by the /// Infer the output shape of the ONNXEluOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); } bool ONNXEluOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Relu // Relu
/// Infer the output shape of the ONNXReluOp. This method is required by the /// Infer the output shape of the ONNXReluOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXReluOp::inferShapes() { getResult().setType(getOperand().getType()); } bool ONNXReluOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// LeakyRelu // LeakyRelu
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by /// Infer the output shape of the ONNXLeakyReluOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXLeakyReluOp::inferShapes() { bool ONNXLeakyReluOp::inferShapes() {
getResult().setType(getOperand().getType()); getResult().setType(getOperand().getType());
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Selu // Selu
/// Infer the output shape of the ONNXSeluOp. This method is required by /// Infer the output shape of the ONNXSeluOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXSeluOp::inferShapes() { getResult().setType(getOperand().getType()); } bool ONNXSeluOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Reciprocal // Reciprocal
/// Infer the output shape of the ONNXReciprocalOp. This method is required by /// Infer the output shape of the ONNXReciprocalOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXReciprocalOp::inferShapes() { bool ONNXReciprocalOp::inferShapes() {
getResult().setType(getOperand().getType()); getResult().setType(getOperand().getType());
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Softmax // Softmax
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by /// Infer the output shape of the ONNXSoftmaxOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXSoftmaxOp::inferShapes() { bool ONNXSoftmaxOp::inferShapes() {
getResult().setType(getOperand().getType()); getResult().setType(getOperand().getType());
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Softplus // Softplus
/// Infer the output shape of the ONNXSoftplusOp. This method is required by /// Infer the output shape of the ONNXSoftplusOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXSoftplusOp::inferShapes() { bool ONNXSoftplusOp::inferShapes() {
getResult().setType(getOperand().getType()); getResult().setType(getOperand().getType());
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Softsign // Softsign
/// Infer the output shape of the ONNXSoftsignOp. This method is required by /// Infer the output shape of the ONNXSoftsignOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXSoftsignOp::inferShapes() { bool ONNXSoftsignOp::inferShapes() {
getResult().setType(getOperand().getType()); getResult().setType(getOperand().getType());
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sqrt // Sqrt
/// Infer the output shape of the ONNXSqrtOp. This method is required by /// Infer the output shape of the ONNXSqrtOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXSqrtOp::inferShapes() { getResult().setType(getOperand().getType()); } bool ONNXSqrtOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sign // Sign
/// Infer the output shape of the ONNXSignOp. This method is required by /// Infer the output shape of the ONNXSignOp. This method is required by
/// the shape inference interface. /// the shape inference interface.
void ONNXSignOp::inferShapes() { getResult().setType(getOperand().getType()); } bool ONNXSignOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Abs // Abs
/// Infer the output shape of the ONNXAbsOp. This method is required by the /// Infer the output shape of the ONNXAbsOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXAbsOp::inferShapes() { getResult().setType(getOperand().getType()); } bool ONNXAbsOp::inferShapes() {
getResult().setType(getOperand().getType());
return true;
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Add // Add
/// Infer the output shape of the ONNXAddOp. This method is required by the /// Infer the output shape of the ONNXAddOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXAddOp::inferShapes() { bool ONNXAddOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>()) {
return; emitError("ONNXAddOp inferShapes failed");
return false;
}
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Mul // Mul
/// Infer the output shape of the ONNXMulOp. This method is required by the /// Infer the output shape of the ONNXMulOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXMulOp::inferShapes() { bool ONNXMulOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return false;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Div // Div
/// Infer the output shape of the ONNXDivOp. This method is required by the /// Infer the output shape of the ONNXDivOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXDivOp::inferShapes() { bool ONNXDivOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return false;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Sub // Sub
/// Infer the output shape of the ONNXSubOp. This method is required by the /// Infer the output shape of the ONNXSubOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXSubOp::inferShapes() { bool ONNXSubOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return false;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// And // And
/// Infer the output shape of the ONNXAndOp. This method is required by the /// Infer the output shape of the ONNXAndOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXAndOp::inferShapes() { bool ONNXAndOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return false;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Or // Or
/// Infer the output shape of the ONNXOrOp. This method is required by the /// Infer the output shape of the ONNXOrOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXOrOp::inferShapes() { bool ONNXOrOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return false;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Xor // Xor
/// Infer the output shape of the ONNXXorOp. This method is required by the /// Infer the output shape of the ONNXXorOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXXorOp::inferShapes() { bool ONNXXorOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>())
return; return false;
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -562,10 +613,10 @@ void ONNXXorOp::inferShapes() {
// Sum // Sum
/// Infer the output shape of the ONNXSumOp. This method is required by the /// Infer the output shape of the ONNXSumOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXSumOp::inferShapes() { bool ONNXSumOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) { for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i).getType().cast<RankedTensorType>()) if (!getOperand(i).getType().cast<RankedTensorType>())
return; return false;
} }
Type resultTy = getOperand(0).getType().cast<RankedTensorType>(); Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) { for (int i = 1; i < getNumOperands(); ++i) {
@ -573,16 +624,17 @@ void ONNXSumOp::inferShapes() {
resultTy = getBroadcastedType(resultTy, nextTy); resultTy = getBroadcastedType(resultTy, nextTy);
} }
getResult().setType(resultTy); getResult().setType(resultTy);
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Max // Max
/// Infer the output shape of the ONNXMaxOp. This method is required by the /// Infer the output shape of the ONNXMaxOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXMaxOp::inferShapes() { bool ONNXMaxOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) { for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i).getType().cast<RankedTensorType>()) if (!getOperand(i).getType().cast<RankedTensorType>())
return; return false;
} }
Type resultTy = getOperand(0).getType().cast<RankedTensorType>(); Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) { for (int i = 1; i < getNumOperands(); ++i) {
@ -590,16 +642,17 @@ void ONNXMaxOp::inferShapes() {
resultTy = getBroadcastedType(resultTy, nextTy); resultTy = getBroadcastedType(resultTy, nextTy);
} }
getResult().setType(resultTy); getResult().setType(resultTy);
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Min // Min
/// Infer the output shape of the ONNXMinOp. This method is required by the /// Infer the output shape of the ONNXMinOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXMinOp::inferShapes() { bool ONNXMinOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) { for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i).getType().cast<RankedTensorType>()) if (!getOperand(i).getType().cast<RankedTensorType>())
return; return false;
} }
Type resultTy = getOperand(0).getType().cast<RankedTensorType>(); Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) { for (int i = 1; i < getNumOperands(); ++i) {
@ -607,25 +660,27 @@ void ONNXMinOp::inferShapes() {
resultTy = getBroadcastedType(resultTy, nextTy); resultTy = getBroadcastedType(resultTy, nextTy);
} }
getResult().setType(resultTy); getResult().setType(resultTy);
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Identity // Identity
/// Infer the output shape of the ONNXIdentityOp. This method is required by the /// Infer the output shape of the ONNXIdentityOp. This method is required by the
/// shape inference interface. /// shape inference interface.
void ONNXIdentityOp::inferShapes() { bool ONNXIdentityOp::inferShapes() {
getResult().setType(getOperand().getType()); getResult().setType(getOperand().getType());
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// MatMul // MatMul
void ONNXMatMulOp::inferShapes() { bool ONNXMatMulOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!A().getType().isa<RankedTensorType>() || if (!A().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>()) !B().getType().isa<RankedTensorType>())
return; return false;
auto lhsTy = A().getType().cast<RankedTensorType>(); auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = B().getType().cast<RankedTensorType>(); auto rhsTy = B().getType().cast<RankedTensorType>();
@ -752,19 +807,20 @@ void ONNXMatMulOp::inferShapes() {
} }
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType())); getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Gemm // Gemm
void ONNXGemmOp::inferShapes() { bool ONNXGemmOp::inferShapes() {
bool hasBias = !C().getType().isa<NoneType>(); bool hasBias = !C().getType().isa<NoneType>();
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!A().getType().isa<RankedTensorType>() || if (!A().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>() || !B().getType().isa<RankedTensorType>() ||
(hasBias && !C().getType().isa<RankedTensorType>())) (hasBias && !C().getType().isa<RankedTensorType>()))
return; return false;
auto lhsTy = A().getType().cast<RankedTensorType>(); auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = B().getType().cast<RankedTensorType>(); auto rhsTy = B().getType().cast<RankedTensorType>();
@ -796,17 +852,18 @@ void ONNXGemmOp::inferShapes() {
dims.emplace_back(M); dims.emplace_back(M);
dims.emplace_back(N); dims.emplace_back(N);
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType())); getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
return true;
} }
/// BatchNormalizationTestMode /// BatchNormalizationTestMode
void ONNXBatchNormalizationTestModeOp::inferShapes() { bool ONNXBatchNormalizationTestModeOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>() || if (!X().getType().isa<RankedTensorType>() ||
!scale().getType().isa<RankedTensorType>() || !scale().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>() || !B().getType().isa<RankedTensorType>() ||
!mean().getType().isa<RankedTensorType>() || !mean().getType().isa<RankedTensorType>() ||
!var().getType().isa<RankedTensorType>()) !var().getType().isa<RankedTensorType>())
return; return false;
auto inputTensorTy = X().getType().cast<RankedTensorType>(); auto inputTensorTy = X().getType().cast<RankedTensorType>();
auto scaleTensorTy = scale().getType().cast<RankedTensorType>(); auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
@ -845,6 +902,7 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {
// The output tensor of the same shape as the input. // The output tensor of the same shape as the input.
getResult().setType(X().getType()); getResult().setType(X().getType());
return true;
} }
// TODO: // TODO:
@ -855,7 +913,7 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {
// Reshape // Reshape
void ONNXReshapeOp::inferShapes() { bool ONNXReshapeOp::inferShapes() {
// Cannot infer shape if no shape tensor is specified. // Cannot infer shape if no shape tensor is specified.
if (!shape().getType().isa<RankedTensorType>()) if (!shape().getType().isa<RankedTensorType>())
emitError("Shape tensor not ranked"); emitError("Shape tensor not ranked");
@ -875,7 +933,7 @@ void ONNXReshapeOp::inferShapes() {
// Compute total number of elements. // Compute total number of elements.
int64_t totalInputSize = 1; int64_t totalInputSize = 1;
for(auto inputDim : inputTensorTy.getShape()) for (auto inputDim : inputTensorTy.getShape())
totalInputSize *= inputDim; totalInputSize *= inputDim;
// Check if second argument of ReshapeOp is a constant. // Check if second argument of ReshapeOp is a constant.
@ -891,7 +949,7 @@ void ONNXReshapeOp::inferShapes() {
// Get dims from valueAttribute. // Get dims from valueAttribute.
auto valueIt = valueAttribute.getValues<IntegerAttr>().begin(); auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
for (int i=0; i<outputRank; ++i) for (int i = 0; i < outputRank; ++i)
dims[i] = (*valueIt++).cast<IntegerAttr>().getInt(); dims[i] = (*valueIt++).cast<IntegerAttr>().getInt();
if (valueIt != valueAttribute.getValues<IntegerAttr>().end()) if (valueIt != valueAttribute.getValues<IntegerAttr>().end())
@ -900,7 +958,7 @@ void ONNXReshapeOp::inferShapes() {
int64_t numberOfDynamicInputs = 0; int64_t numberOfDynamicInputs = 0;
int64_t totalKnownDimsSize = 1; int64_t totalKnownDimsSize = 1;
int64_t dynamicValueIndex = -1; int64_t dynamicValueIndex = -1;
for (int i=0; i<outputRank; ++i) { for (int i = 0; i < outputRank; ++i) {
// Set output dimension. // Set output dimension.
if (dims[i] == 0) if (dims[i] == 0)
dims[i] = inputTensorTy.getShape()[i]; dims[i] = inputTensorTy.getShape()[i];
@ -924,16 +982,17 @@ void ONNXReshapeOp::inferShapes() {
getResult().setType( getResult().setType(
RankedTensorType::get(dims, inputTensorTy.getElementType())); RankedTensorType::get(dims, inputTensorTy.getElementType()));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Transpose // Transpose
void ONNXTransposeOp::inferShapes() { bool ONNXTransposeOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!data().getType().isa<RankedTensorType>()) if (!data().getType().isa<RankedTensorType>())
return; return false;
// Naive transposition which handles the default case of // Naive transposition which handles the default case of
// reversing the shape of the tensor (similar to numpy.transpose). // reversing the shape of the tensor (similar to numpy.transpose).
@ -951,62 +1010,67 @@ void ONNXTransposeOp::inferShapes() {
} }
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType())); getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ReduceMax // ReduceMax
void ONNXReduceMaxOp::inferShapes() { bool ONNXReduceMaxOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) { if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked"); emitError("Shape tensor not ranked");
return; return false;
} }
auto operandTy = getOperand().getType().cast<RankedTensorType>(); auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims())); getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ReduceMin // ReduceMin
void ONNXReduceMinOp::inferShapes() { bool ONNXReduceMinOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) { if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked"); emitError("Shape tensor not ranked");
return; return false;
} }
auto operandTy = getOperand().getType().cast<RankedTensorType>(); auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims())); getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ReduceProd // ReduceProd
void ONNXReduceProdOp::inferShapes() { bool ONNXReduceProdOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) { if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked"); emitError("Shape tensor not ranked");
return; return false;
} }
auto operandTy = getOperand().getType().cast<RankedTensorType>(); auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims())); getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ReduceSum // ReduceSum
void ONNXReduceSumOp::inferShapes() { bool ONNXReduceSumOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) { if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked"); emitError("Shape tensor not ranked");
return; return false;
} }
auto operandTy = getOperand().getType().cast<RankedTensorType>(); auto operandTy = getOperand().getType().cast<RankedTensorType>();
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims())); getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1022,7 +1086,7 @@ void ONNXReduceSumOp::inferShapes() {
// - kernelShape: inferred from weight matrix if not defined by user; // - kernelShape: inferred from weight matrix if not defined by user;
// - pads: set to proper value, 0 if not defined by user. // - pads: set to proper value, 0 if not defined by user.
void ONNXConvOp::inferShapes() { bool ONNXConvOp::inferShapes() {
// Generic shape for data input X, weight tensor W, and optional bias B // Generic shape for data input X, weight tensor W, and optional bias B
// X: (N x C x D1 x D2 ... x Dn) // X: (N x C x D1 x D2 ... x Dn)
// W: (M x C/group x k1 x k2 x ... x kn) // W: (M x C/group x k1 x k2 x ... x kn)
@ -1034,7 +1098,7 @@ void ONNXConvOp::inferShapes() {
if (!X().getType().isa<RankedTensorType>() || if (!X().getType().isa<RankedTensorType>() ||
!W().getType().isa<RankedTensorType>() || !W().getType().isa<RankedTensorType>() ||
(hasBias && !B().getType().isa<RankedTensorType>())) (hasBias && !B().getType().isa<RankedTensorType>()))
return; return false;
auto xTy = X().getType().cast<RankedTensorType>(); auto xTy = X().getType().cast<RankedTensorType>();
auto xShape = xTy.getShape(); auto xShape = xTy.getShape();
@ -1043,12 +1107,16 @@ void ONNXConvOp::inferShapes() {
auto builder = mlir::Builder(this->getContext()); auto builder = mlir::Builder(this->getContext());
// Lowest supported convolution is a one dimensional convolution. // Lowest supported convolution is a one dimensional convolution.
if (xShape.size() < 3) if (xShape.size() < 3) {
emitError("Data input shape must be at least (NxCxD1)"); emitError("Data input shape must be at least (NxCxD1)");
return false;
}
// Check that shape of weight and data have same length. // Check that shape of weight and data have same length.
if (xShape.size() != weightShape.size()) if (xShape.size() != weightShape.size()) {
emitError("Weight size not compatible with data size"); emitError("Weight size not compatible with data size");
return false;
}
// Group is a required attribute and should have default value of 1. // Group is a required attribute and should have default value of 1.
int64_t group = ONNXConvOp::group().getSExtValue(); int64_t group = ONNXConvOp::group().getSExtValue();
@ -1059,17 +1127,23 @@ void ONNXConvOp::inferShapes() {
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds. // Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
if (xShape[1] != -1 && weightShape[1] != -1 && if (xShape[1] != -1 && weightShape[1] != -1 &&
xShape[1] != (weightShape[1] * group)) xShape[1] != (weightShape[1] * group)) {
emitError("Channel dimension mismatch"); emitError("Channel dimension mismatch");
return false;
}
// Check the size of bias. // Check the size of bias.
if (hasBias) { if (hasBias) {
auto bTx = B().getType().cast<RankedTensorType>(); auto bTx = B().getType().cast<RankedTensorType>();
auto bShape = bTx.getShape(); auto bShape = bTx.getShape();
if (bShape.size() != 1) if (bShape.size() != 1) {
emitError("bias should be one dimensional"); emitError("bias should be one dimensional");
if (bShape[0] != weightShape[0]) return false;
}
if (bShape[0] != weightShape[0]) {
emitError("bias should have same dimensions as weight's first dimension"); emitError("bias should have same dimensions as weight's first dimension");
return false;
}
} }
// Note: the value of the group attribut only impacts the way the // Note: the value of the group attribut only impacts the way the
@ -1083,12 +1157,16 @@ void ONNXConvOp::inferShapes() {
// argument. // argument.
auto kernelShape = kernel_shape(); auto kernelShape = kernel_shape();
if (kernelShape.hasValue()) { if (kernelShape.hasValue()) {
if (ArrayAttrSize(kernelShape) != spatialRank) if (ArrayAttrSize(kernelShape) != spatialRank) {
emitError("kernel_shape length incompatible with spatial dimensions"); emitError("kernel_shape length incompatible with spatial dimensions");
return false;
}
// Have the right number of values, check them. // Have the right number of values, check them.
for (int i = 0; i < spatialRank; ++i) for (int i = 0; i < spatialRank; ++i)
if (ArrayAttrIntVal(kernelShape, i) < 1) if (ArrayAttrIntVal(kernelShape, i) < 1) {
emitError("bad kernel_shape value"); emitError("bad kernel_shape value");
return false;
}
} else { } else {
// Deduce shape from weight input. // Deduce shape from weight input.
SmallVector<int64_t, 2> defaultVals; SmallVector<int64_t, 2> defaultVals;
@ -1119,6 +1197,7 @@ void ONNXConvOp::inferShapes() {
&outputDims, xShape, kernelShape, padsOpt, stridesOpt, dilationsOpt); &outputDims, xShape, kernelShape, padsOpt, stridesOpt, dilationsOpt);
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType())); getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1129,10 +1208,10 @@ void ONNXConvOp::inferShapes() {
// - strides: set to 1 if not defined by user; // - strides: set to 1 if not defined by user;
// - pads: set to proper value, 0 if not defined by user. // - pads: set to proper value, 0 if not defined by user.
void ONNXAveragePoolOp::inferShapes() { bool ONNXAveragePoolOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>()) if (!X().getType().isa<RankedTensorType>())
return; return false;
// Get shape of input. // Get shape of input.
auto xTy = X().getType().cast<RankedTensorType>(); auto xTy = X().getType().cast<RankedTensorType>();
@ -1163,6 +1242,7 @@ void ONNXAveragePoolOp::inferShapes() {
llvm::None, ceilMode); llvm::None, ceilMode);
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType())); getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1173,10 +1253,10 @@ void ONNXAveragePoolOp::inferShapes() {
// - dilations, strides: set to 1 if not defined by user; // - dilations, strides: set to 1 if not defined by user;
// - pads: set to proper value, 0 if not defined by user. // - pads: set to proper value, 0 if not defined by user.
void ONNXMaxPoolSingleOutOp::inferShapes() { bool ONNXMaxPoolSingleOutOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>()) if (!X().getType().isa<RankedTensorType>())
return; return false;
// Get shape of input. // Get shape of input.
auto xTy = X().getType().cast<RankedTensorType>(); auto xTy = X().getType().cast<RankedTensorType>();
@ -1211,6 +1291,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
dilationsOpt, ceilMode); dilationsOpt, ceilMode);
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType())); getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -1246,24 +1327,26 @@ static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
// PadConstantPad // PadConstantPad
void ONNXPadConstantPadOp::inferShapes() { bool ONNXPadConstantPadOp::inferShapes() {
auto outputType = padShapeInferenceHelper(data(), pads()); auto outputType = padShapeInferenceHelper(data(), pads());
if (outputType) { if (outputType) {
getResult().setType(outputType); getResult().setType(outputType);
return true;
} }
return; return false;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// PadConstantValuePad // PadConstantValuePad
void ONNXPadConstantValuePadOp::inferShapes() { bool ONNXPadConstantValuePadOp::inferShapes() {
auto outputType = padShapeInferenceHelper(data(), pads()); auto outputType = padShapeInferenceHelper(data(), pads());
if (outputType) { if (outputType) {
getResult().setType(outputType); getResult().setType(outputType);
return true;
} }
return; return false;
} }
void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state, void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
@ -1280,9 +1363,9 @@ void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
// Unsqueeze // Unsqueeze
void ONNXUnsqueezeOp::inferShapes() { bool ONNXUnsqueezeOp::inferShapes() {
if (!data().getType().isa<RankedTensorType>()) if (!data().getType().isa<RankedTensorType>())
return; return false;
auto operandTy = data().getType().cast<RankedTensorType>(); auto operandTy = data().getType().cast<RankedTensorType>();
int inRank = operandTy.getRank(); int inRank = operandTy.getRank();
@ -1299,11 +1382,14 @@ void ONNXUnsqueezeOp::inferShapes() {
assert(axis >= -outRank && axis <= outRank - 1); assert(axis >= -outRank && axis <= outRank - 1);
if (std::find(axes.begin(), axes.end(), axis) == axes.end()) if (std::find(axes.begin(), axes.end(), axis) == axes.end())
axes.emplace_back(axis); axes.emplace_back(axis);
else else {
emitError("Duplicated axes"); emitError("Duplicated axes");
return false;
}
} }
} else { } else {
emitError("Axes attribute is required"); emitError("Axes attribute is required");
return false;
} }
SmallVector<int64_t, 4> dims; SmallVector<int64_t, 4> dims;
@ -1315,12 +1401,13 @@ void ONNXUnsqueezeOp::inferShapes() {
} }
} }
getResult().setType(RankedTensorType::get(dims, operandTy.getElementType())); getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Constant // Constant
void ONNXConstantOp::inferShapes() { bool ONNXConstantOp::inferShapes() {
if ((sparse_value().hasValue() && value().hasValue()) || if ((sparse_value().hasValue() && value().hasValue()) ||
(!sparse_value().hasValue() && !value().hasValue())) (!sparse_value().hasValue() && !value().hasValue()))
emitError("Require exactly one of the two attributes, either value or " emitError("Require exactly one of the two attributes, either value or "
@ -1332,6 +1419,7 @@ void ONNXConstantOp::inferShapes() {
else else
valAttr = valueAttr().cast<DenseElementsAttr>(); valAttr = valueAttr().cast<DenseElementsAttr>();
getResult().setType(valAttr.getType()); getResult().setType(valAttr.getType());
return true;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -25,7 +25,7 @@ def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
let methods = [ let methods = [
InterfaceMethod<"Infer and set the output shape for the current operation.", InterfaceMethod<"Infer and set the output shape for the current operation.",
"void", "inferShapes"> "bool", "inferShapes">
]; ];
} }

View File

@ -35,7 +35,11 @@ public:
f.walk([&](mlir::Operation *op) { f.walk([&](mlir::Operation *op) {
if (returnsDynamicShape(op)) { if (returnsDynamicShape(op)) {
if (auto shape_op = dyn_cast<ShapeInference>(op)) { if (auto shape_op = dyn_cast<ShapeInference>(op)) {
shape_op.inferShapes(); if (!shape_op.inferShapes()) {
op->emitError("unable to infer shape of operation without shape "
"inference interface");
return signalPassFailure();
}
} else { } else {
op->emitError("unable to infer shape of operation without shape " op->emitError("unable to infer shape of operation without shape "
"inference interface"); "inference interface");