Handle error shape inference (#47)
* add return to inferShape * ran clang-format * minor changes according to review * fix format
This commit is contained in:
parent
867406191f
commit
55cbe316fd
|
@ -151,9 +151,8 @@ static void processConvStrideParam(T *op, Optional<ArrayAttr> kernelShape) {
|
||||||
// Support function that computes default values for pads.
|
// Support function that computes default values for pads.
|
||||||
//
|
//
|
||||||
template <class T>
|
template <class T>
|
||||||
static void processConvPadParam(T *op,
|
static void processConvPadParam(T *op, ArrayRef<int64_t> inputShape,
|
||||||
ArrayRef<int64_t> inputShape, Optional<ArrayAttr> kernelShape,
|
Optional<ArrayAttr> kernelShape, Optional<ArrayAttr> stridesOpt,
|
||||||
Optional<ArrayAttr> stridesOpt,
|
|
||||||
Optional<ArrayAttr> dilationsOpt = llvm::None) {
|
Optional<ArrayAttr> dilationsOpt = llvm::None) {
|
||||||
auto builder = mlir::Builder(op->getContext());
|
auto builder = mlir::Builder(op->getContext());
|
||||||
|
|
||||||
|
@ -341,219 +340,271 @@ ONNXEntryPointOp ONNXEntryPointOp::create(mlir::Location location,
|
||||||
// Exp
|
// Exp
|
||||||
/// Infer the output shape of the ONNXExpOp. This method is required by the
|
/// Infer the output shape of the ONNXExpOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXExpOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
bool ONNXExpOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Tanh
|
// Tanh
|
||||||
/// Infer the output shape of the ONNXTanhOp. This method is required by the
|
/// Infer the output shape of the ONNXTanhOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXTanhOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
bool ONNXTanhOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Sinh
|
// Sinh
|
||||||
/// Infer the output shape of the ONNXSinhOp. This method is required by the
|
/// Infer the output shape of the ONNXSinhOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXSinhOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
bool ONNXSinhOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Cosh
|
// Cosh
|
||||||
/// Infer the output shape of the ONNXCoshOp. This method is required by the
|
/// Infer the output shape of the ONNXCoshOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXCoshOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
bool ONNXCoshOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Cos
|
// Cos
|
||||||
/// Infer the output shape of the ONNXCosOp. This method is required by the
|
/// Infer the output shape of the ONNXCosOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXCosOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
bool ONNXCosOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Log
|
// Log
|
||||||
/// Infer the output shape of the ONNXLogOp. This method is required by the
|
/// Infer the output shape of the ONNXLogOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXLogOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
bool ONNXLogOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// HardSigmoid
|
// HardSigmoid
|
||||||
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
|
/// Infer the output shape of the ONNXHardSigmoidOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
void ONNXHardSigmoidOp::inferShapes() {
|
bool ONNXHardSigmoidOp::inferShapes() {
|
||||||
getResult().setType(getOperand().getType());
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Sigmoid
|
// Sigmoid
|
||||||
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
|
/// Infer the output shape of the ONNXSigmoidOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXSigmoidOp::inferShapes() {
|
bool ONNXSigmoidOp::inferShapes() {
|
||||||
getResult().setType(getOperand().getType());
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Elu
|
// Elu
|
||||||
/// Infer the output shape of the ONNXEluOp. This method is required by the
|
/// Infer the output shape of the ONNXEluOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXEluOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
bool ONNXEluOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Relu
|
// Relu
|
||||||
/// Infer the output shape of the ONNXReluOp. This method is required by the
|
/// Infer the output shape of the ONNXReluOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXReluOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
bool ONNXReluOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// LeakyRelu
|
// LeakyRelu
|
||||||
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
|
/// Infer the output shape of the ONNXLeakyReluOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
void ONNXLeakyReluOp::inferShapes() {
|
bool ONNXLeakyReluOp::inferShapes() {
|
||||||
getResult().setType(getOperand().getType());
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Selu
|
// Selu
|
||||||
/// Infer the output shape of the ONNXSeluOp. This method is required by
|
/// Infer the output shape of the ONNXSeluOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
void ONNXSeluOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
bool ONNXSeluOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Reciprocal
|
// Reciprocal
|
||||||
/// Infer the output shape of the ONNXReciprocalOp. This method is required by
|
/// Infer the output shape of the ONNXReciprocalOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
void ONNXReciprocalOp::inferShapes() {
|
bool ONNXReciprocalOp::inferShapes() {
|
||||||
getResult().setType(getOperand().getType());
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Softmax
|
// Softmax
|
||||||
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
|
/// Infer the output shape of the ONNXSoftmaxOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
void ONNXSoftmaxOp::inferShapes() {
|
bool ONNXSoftmaxOp::inferShapes() {
|
||||||
getResult().setType(getOperand().getType());
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Softplus
|
// Softplus
|
||||||
/// Infer the output shape of the ONNXSoftplusOp. This method is required by
|
/// Infer the output shape of the ONNXSoftplusOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
void ONNXSoftplusOp::inferShapes() {
|
bool ONNXSoftplusOp::inferShapes() {
|
||||||
getResult().setType(getOperand().getType());
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Softsign
|
// Softsign
|
||||||
/// Infer the output shape of the ONNXSoftsignOp. This method is required by
|
/// Infer the output shape of the ONNXSoftsignOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
void ONNXSoftsignOp::inferShapes() {
|
bool ONNXSoftsignOp::inferShapes() {
|
||||||
getResult().setType(getOperand().getType());
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Sqrt
|
// Sqrt
|
||||||
/// Infer the output shape of the ONNXSqrtOp. This method is required by
|
/// Infer the output shape of the ONNXSqrtOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
void ONNXSqrtOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
bool ONNXSqrtOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Sign
|
// Sign
|
||||||
/// Infer the output shape of the ONNXSignOp. This method is required by
|
/// Infer the output shape of the ONNXSignOp. This method is required by
|
||||||
/// the shape inference interface.
|
/// the shape inference interface.
|
||||||
void ONNXSignOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
bool ONNXSignOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Abs
|
// Abs
|
||||||
/// Infer the output shape of the ONNXAbsOp. This method is required by the
|
/// Infer the output shape of the ONNXAbsOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXAbsOp::inferShapes() { getResult().setType(getOperand().getType()); }
|
bool ONNXAbsOp::inferShapes() {
|
||||||
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Add
|
// Add
|
||||||
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
/// Infer the output shape of the ONNXAddOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXAddOp::inferShapes() {
|
bool ONNXAddOp::inferShapes() {
|
||||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(1).getType().isa<RankedTensorType>())
|
!getOperand(1).getType().isa<RankedTensorType>()) {
|
||||||
return;
|
emitError("ONNXAddOp inferShapes failed");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Mul
|
// Mul
|
||||||
/// Infer the output shape of the ONNXMulOp. This method is required by the
|
/// Infer the output shape of the ONNXMulOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXMulOp::inferShapes() {
|
bool ONNXMulOp::inferShapes() {
|
||||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(1).getType().isa<RankedTensorType>())
|
!getOperand(1).getType().isa<RankedTensorType>())
|
||||||
return;
|
return false;
|
||||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Div
|
// Div
|
||||||
/// Infer the output shape of the ONNXDivOp. This method is required by the
|
/// Infer the output shape of the ONNXDivOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXDivOp::inferShapes() {
|
bool ONNXDivOp::inferShapes() {
|
||||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(1).getType().isa<RankedTensorType>())
|
!getOperand(1).getType().isa<RankedTensorType>())
|
||||||
return;
|
return false;
|
||||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Sub
|
// Sub
|
||||||
/// Infer the output shape of the ONNXSubOp. This method is required by the
|
/// Infer the output shape of the ONNXSubOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXSubOp::inferShapes() {
|
bool ONNXSubOp::inferShapes() {
|
||||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(1).getType().isa<RankedTensorType>())
|
!getOperand(1).getType().isa<RankedTensorType>())
|
||||||
return;
|
return false;
|
||||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// And
|
// And
|
||||||
/// Infer the output shape of the ONNXAndOp. This method is required by the
|
/// Infer the output shape of the ONNXAndOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXAndOp::inferShapes() {
|
bool ONNXAndOp::inferShapes() {
|
||||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(1).getType().isa<RankedTensorType>())
|
!getOperand(1).getType().isa<RankedTensorType>())
|
||||||
return;
|
return false;
|
||||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Or
|
// Or
|
||||||
/// Infer the output shape of the ONNXOrOp. This method is required by the
|
/// Infer the output shape of the ONNXOrOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXOrOp::inferShapes() {
|
bool ONNXOrOp::inferShapes() {
|
||||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(1).getType().isa<RankedTensorType>())
|
!getOperand(1).getType().isa<RankedTensorType>())
|
||||||
return;
|
return false;
|
||||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Xor
|
// Xor
|
||||||
/// Infer the output shape of the ONNXXorOp. This method is required by the
|
/// Infer the output shape of the ONNXXorOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXXorOp::inferShapes() {
|
bool ONNXXorOp::inferShapes() {
|
||||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||||
!getOperand(1).getType().isa<RankedTensorType>())
|
!getOperand(1).getType().isa<RankedTensorType>())
|
||||||
return;
|
return false;
|
||||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -562,10 +613,10 @@ void ONNXXorOp::inferShapes() {
|
||||||
// Sum
|
// Sum
|
||||||
/// Infer the output shape of the ONNXSumOp. This method is required by the
|
/// Infer the output shape of the ONNXSumOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXSumOp::inferShapes() {
|
bool ONNXSumOp::inferShapes() {
|
||||||
for (int i = 0; i < getNumOperands(); ++i) {
|
for (int i = 0; i < getNumOperands(); ++i) {
|
||||||
if (!getOperand(i).getType().cast<RankedTensorType>())
|
if (!getOperand(i).getType().cast<RankedTensorType>())
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||||
for (int i = 1; i < getNumOperands(); ++i) {
|
for (int i = 1; i < getNumOperands(); ++i) {
|
||||||
|
@ -573,16 +624,17 @@ void ONNXSumOp::inferShapes() {
|
||||||
resultTy = getBroadcastedType(resultTy, nextTy);
|
resultTy = getBroadcastedType(resultTy, nextTy);
|
||||||
}
|
}
|
||||||
getResult().setType(resultTy);
|
getResult().setType(resultTy);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Max
|
// Max
|
||||||
/// Infer the output shape of the ONNXMaxOp. This method is required by the
|
/// Infer the output shape of the ONNXMaxOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXMaxOp::inferShapes() {
|
bool ONNXMaxOp::inferShapes() {
|
||||||
for (int i = 0; i < getNumOperands(); ++i) {
|
for (int i = 0; i < getNumOperands(); ++i) {
|
||||||
if (!getOperand(i).getType().cast<RankedTensorType>())
|
if (!getOperand(i).getType().cast<RankedTensorType>())
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||||
for (int i = 1; i < getNumOperands(); ++i) {
|
for (int i = 1; i < getNumOperands(); ++i) {
|
||||||
|
@ -590,16 +642,17 @@ void ONNXMaxOp::inferShapes() {
|
||||||
resultTy = getBroadcastedType(resultTy, nextTy);
|
resultTy = getBroadcastedType(resultTy, nextTy);
|
||||||
}
|
}
|
||||||
getResult().setType(resultTy);
|
getResult().setType(resultTy);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Min
|
// Min
|
||||||
/// Infer the output shape of the ONNXMinOp. This method is required by the
|
/// Infer the output shape of the ONNXMinOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXMinOp::inferShapes() {
|
bool ONNXMinOp::inferShapes() {
|
||||||
for (int i = 0; i < getNumOperands(); ++i) {
|
for (int i = 0; i < getNumOperands(); ++i) {
|
||||||
if (!getOperand(i).getType().cast<RankedTensorType>())
|
if (!getOperand(i).getType().cast<RankedTensorType>())
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||||
for (int i = 1; i < getNumOperands(); ++i) {
|
for (int i = 1; i < getNumOperands(); ++i) {
|
||||||
|
@ -607,25 +660,27 @@ void ONNXMinOp::inferShapes() {
|
||||||
resultTy = getBroadcastedType(resultTy, nextTy);
|
resultTy = getBroadcastedType(resultTy, nextTy);
|
||||||
}
|
}
|
||||||
getResult().setType(resultTy);
|
getResult().setType(resultTy);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Identity
|
// Identity
|
||||||
/// Infer the output shape of the ONNXIdentityOp. This method is required by the
|
/// Infer the output shape of the ONNXIdentityOp. This method is required by the
|
||||||
/// shape inference interface.
|
/// shape inference interface.
|
||||||
void ONNXIdentityOp::inferShapes() {
|
bool ONNXIdentityOp::inferShapes() {
|
||||||
getResult().setType(getOperand().getType());
|
getResult().setType(getOperand().getType());
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// MatMul
|
// MatMul
|
||||||
|
|
||||||
void ONNXMatMulOp::inferShapes() {
|
bool ONNXMatMulOp::inferShapes() {
|
||||||
// Cannot infer shape if no shape exists.
|
// Cannot infer shape if no shape exists.
|
||||||
if (!A().getType().isa<RankedTensorType>() ||
|
if (!A().getType().isa<RankedTensorType>() ||
|
||||||
!B().getType().isa<RankedTensorType>())
|
!B().getType().isa<RankedTensorType>())
|
||||||
return;
|
return false;
|
||||||
|
|
||||||
auto lhsTy = A().getType().cast<RankedTensorType>();
|
auto lhsTy = A().getType().cast<RankedTensorType>();
|
||||||
auto rhsTy = B().getType().cast<RankedTensorType>();
|
auto rhsTy = B().getType().cast<RankedTensorType>();
|
||||||
|
@ -752,19 +807,20 @@ void ONNXMatMulOp::inferShapes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Gemm
|
// Gemm
|
||||||
|
|
||||||
void ONNXGemmOp::inferShapes() {
|
bool ONNXGemmOp::inferShapes() {
|
||||||
bool hasBias = !C().getType().isa<NoneType>();
|
bool hasBias = !C().getType().isa<NoneType>();
|
||||||
// Cannot infer shape if no shape exists.
|
// Cannot infer shape if no shape exists.
|
||||||
if (!A().getType().isa<RankedTensorType>() ||
|
if (!A().getType().isa<RankedTensorType>() ||
|
||||||
!B().getType().isa<RankedTensorType>() ||
|
!B().getType().isa<RankedTensorType>() ||
|
||||||
(hasBias && !C().getType().isa<RankedTensorType>()))
|
(hasBias && !C().getType().isa<RankedTensorType>()))
|
||||||
return;
|
return false;
|
||||||
auto lhsTy = A().getType().cast<RankedTensorType>();
|
auto lhsTy = A().getType().cast<RankedTensorType>();
|
||||||
auto rhsTy = B().getType().cast<RankedTensorType>();
|
auto rhsTy = B().getType().cast<RankedTensorType>();
|
||||||
|
|
||||||
|
@ -796,17 +852,18 @@ void ONNXGemmOp::inferShapes() {
|
||||||
dims.emplace_back(M);
|
dims.emplace_back(M);
|
||||||
dims.emplace_back(N);
|
dims.emplace_back(N);
|
||||||
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, lhsTy.getElementType()));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// BatchNormalizationTestMode
|
/// BatchNormalizationTestMode
|
||||||
void ONNXBatchNormalizationTestModeOp::inferShapes() {
|
bool ONNXBatchNormalizationTestModeOp::inferShapes() {
|
||||||
// Cannot infer shape if no shape exists.
|
// Cannot infer shape if no shape exists.
|
||||||
if (!X().getType().isa<RankedTensorType>() ||
|
if (!X().getType().isa<RankedTensorType>() ||
|
||||||
!scale().getType().isa<RankedTensorType>() ||
|
!scale().getType().isa<RankedTensorType>() ||
|
||||||
!B().getType().isa<RankedTensorType>() ||
|
!B().getType().isa<RankedTensorType>() ||
|
||||||
!mean().getType().isa<RankedTensorType>() ||
|
!mean().getType().isa<RankedTensorType>() ||
|
||||||
!var().getType().isa<RankedTensorType>())
|
!var().getType().isa<RankedTensorType>())
|
||||||
return;
|
return false;
|
||||||
|
|
||||||
auto inputTensorTy = X().getType().cast<RankedTensorType>();
|
auto inputTensorTy = X().getType().cast<RankedTensorType>();
|
||||||
auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
|
auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
|
||||||
|
@ -845,6 +902,7 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {
|
||||||
|
|
||||||
// The output tensor of the same shape as the input.
|
// The output tensor of the same shape as the input.
|
||||||
getResult().setType(X().getType());
|
getResult().setType(X().getType());
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO:
|
// TODO:
|
||||||
|
@ -855,7 +913,7 @@ void ONNXBatchNormalizationTestModeOp::inferShapes() {
|
||||||
|
|
||||||
// Reshape
|
// Reshape
|
||||||
|
|
||||||
void ONNXReshapeOp::inferShapes() {
|
bool ONNXReshapeOp::inferShapes() {
|
||||||
// Cannot infer shape if no shape tensor is specified.
|
// Cannot infer shape if no shape tensor is specified.
|
||||||
if (!shape().getType().isa<RankedTensorType>())
|
if (!shape().getType().isa<RankedTensorType>())
|
||||||
emitError("Shape tensor not ranked");
|
emitError("Shape tensor not ranked");
|
||||||
|
@ -924,16 +982,17 @@ void ONNXReshapeOp::inferShapes() {
|
||||||
|
|
||||||
getResult().setType(
|
getResult().setType(
|
||||||
RankedTensorType::get(dims, inputTensorTy.getElementType()));
|
RankedTensorType::get(dims, inputTensorTy.getElementType()));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Transpose
|
// Transpose
|
||||||
|
|
||||||
void ONNXTransposeOp::inferShapes() {
|
bool ONNXTransposeOp::inferShapes() {
|
||||||
// Cannot infer shape if no shape exists.
|
// Cannot infer shape if no shape exists.
|
||||||
if (!data().getType().isa<RankedTensorType>())
|
if (!data().getType().isa<RankedTensorType>())
|
||||||
return;
|
return false;
|
||||||
|
|
||||||
// Naive transposition which handles the default case of
|
// Naive transposition which handles the default case of
|
||||||
// reversing the shape of the tensor (similar to numpy.transpose).
|
// reversing the shape of the tensor (similar to numpy.transpose).
|
||||||
|
@ -951,62 +1010,67 @@ void ONNXTransposeOp::inferShapes() {
|
||||||
}
|
}
|
||||||
|
|
||||||
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// ReduceMax
|
// ReduceMax
|
||||||
|
|
||||||
void ONNXReduceMaxOp::inferShapes() {
|
bool ONNXReduceMaxOp::inferShapes() {
|
||||||
if (!getOperand().getType().isa<RankedTensorType>()) {
|
if (!getOperand().getType().isa<RankedTensorType>()) {
|
||||||
emitError("Shape tensor not ranked");
|
emitError("Shape tensor not ranked");
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
||||||
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// ReduceMin
|
// ReduceMin
|
||||||
|
|
||||||
void ONNXReduceMinOp::inferShapes() {
|
bool ONNXReduceMinOp::inferShapes() {
|
||||||
if (!getOperand().getType().isa<RankedTensorType>()) {
|
if (!getOperand().getType().isa<RankedTensorType>()) {
|
||||||
emitError("Shape tensor not ranked");
|
emitError("Shape tensor not ranked");
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
||||||
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// ReduceProd
|
// ReduceProd
|
||||||
|
|
||||||
void ONNXReduceProdOp::inferShapes() {
|
bool ONNXReduceProdOp::inferShapes() {
|
||||||
if (!getOperand().getType().isa<RankedTensorType>()) {
|
if (!getOperand().getType().isa<RankedTensorType>()) {
|
||||||
emitError("Shape tensor not ranked");
|
emitError("Shape tensor not ranked");
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
||||||
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// ReduceSum
|
// ReduceSum
|
||||||
|
|
||||||
void ONNXReduceSumOp::inferShapes() {
|
bool ONNXReduceSumOp::inferShapes() {
|
||||||
if (!getOperand().getType().isa<RankedTensorType>()) {
|
if (!getOperand().getType().isa<RankedTensorType>()) {
|
||||||
emitError("Shape tensor not ranked");
|
emitError("Shape tensor not ranked");
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
auto operandTy = getOperand().getType().cast<RankedTensorType>();
|
||||||
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
getResult().setType(getReductionOutputType(operandTy, axes(), keepdims()));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1022,7 +1086,7 @@ void ONNXReduceSumOp::inferShapes() {
|
||||||
// - kernelShape: inferred from weight matrix if not defined by user;
|
// - kernelShape: inferred from weight matrix if not defined by user;
|
||||||
// - pads: set to proper value, 0 if not defined by user.
|
// - pads: set to proper value, 0 if not defined by user.
|
||||||
|
|
||||||
void ONNXConvOp::inferShapes() {
|
bool ONNXConvOp::inferShapes() {
|
||||||
// Generic shape for data input X, weight tensor W, and optional bias B
|
// Generic shape for data input X, weight tensor W, and optional bias B
|
||||||
// X: (N x C x D1 x D2 ... x Dn)
|
// X: (N x C x D1 x D2 ... x Dn)
|
||||||
// W: (M x C/group x k1 x k2 x ... x kn)
|
// W: (M x C/group x k1 x k2 x ... x kn)
|
||||||
|
@ -1034,7 +1098,7 @@ void ONNXConvOp::inferShapes() {
|
||||||
if (!X().getType().isa<RankedTensorType>() ||
|
if (!X().getType().isa<RankedTensorType>() ||
|
||||||
!W().getType().isa<RankedTensorType>() ||
|
!W().getType().isa<RankedTensorType>() ||
|
||||||
(hasBias && !B().getType().isa<RankedTensorType>()))
|
(hasBias && !B().getType().isa<RankedTensorType>()))
|
||||||
return;
|
return false;
|
||||||
|
|
||||||
auto xTy = X().getType().cast<RankedTensorType>();
|
auto xTy = X().getType().cast<RankedTensorType>();
|
||||||
auto xShape = xTy.getShape();
|
auto xShape = xTy.getShape();
|
||||||
|
@ -1043,12 +1107,16 @@ void ONNXConvOp::inferShapes() {
|
||||||
auto builder = mlir::Builder(this->getContext());
|
auto builder = mlir::Builder(this->getContext());
|
||||||
|
|
||||||
// Lowest supported convolution is a one dimensional convolution.
|
// Lowest supported convolution is a one dimensional convolution.
|
||||||
if (xShape.size() < 3)
|
if (xShape.size() < 3) {
|
||||||
emitError("Data input shape must be at least (NxCxD1)");
|
emitError("Data input shape must be at least (NxCxD1)");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// Check that shape of weight and data have same length.
|
// Check that shape of weight and data have same length.
|
||||||
if (xShape.size() != weightShape.size())
|
if (xShape.size() != weightShape.size()) {
|
||||||
emitError("Weight size not compatible with data size");
|
emitError("Weight size not compatible with data size");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// Group is a required attribute and should have default value of 1.
|
// Group is a required attribute and should have default value of 1.
|
||||||
int64_t group = ONNXConvOp::group().getSExtValue();
|
int64_t group = ONNXConvOp::group().getSExtValue();
|
||||||
|
@ -1059,17 +1127,23 @@ void ONNXConvOp::inferShapes() {
|
||||||
|
|
||||||
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
|
// Check that the X.shape[1] == (W.shape[1] * group) == C condition holds.
|
||||||
if (xShape[1] != -1 && weightShape[1] != -1 &&
|
if (xShape[1] != -1 && weightShape[1] != -1 &&
|
||||||
xShape[1] != (weightShape[1] * group))
|
xShape[1] != (weightShape[1] * group)) {
|
||||||
emitError("Channel dimension mismatch");
|
emitError("Channel dimension mismatch");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// Check the size of bias.
|
// Check the size of bias.
|
||||||
if (hasBias) {
|
if (hasBias) {
|
||||||
auto bTx = B().getType().cast<RankedTensorType>();
|
auto bTx = B().getType().cast<RankedTensorType>();
|
||||||
auto bShape = bTx.getShape();
|
auto bShape = bTx.getShape();
|
||||||
if (bShape.size() != 1)
|
if (bShape.size() != 1) {
|
||||||
emitError("bias should be one dimensional");
|
emitError("bias should be one dimensional");
|
||||||
if (bShape[0] != weightShape[0])
|
return false;
|
||||||
|
}
|
||||||
|
if (bShape[0] != weightShape[0]) {
|
||||||
emitError("bias should have same dimensions as weight's first dimension");
|
emitError("bias should have same dimensions as weight's first dimension");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: the value of the group attribut only impacts the way the
|
// Note: the value of the group attribut only impacts the way the
|
||||||
|
@ -1083,12 +1157,16 @@ void ONNXConvOp::inferShapes() {
|
||||||
// argument.
|
// argument.
|
||||||
auto kernelShape = kernel_shape();
|
auto kernelShape = kernel_shape();
|
||||||
if (kernelShape.hasValue()) {
|
if (kernelShape.hasValue()) {
|
||||||
if (ArrayAttrSize(kernelShape) != spatialRank)
|
if (ArrayAttrSize(kernelShape) != spatialRank) {
|
||||||
emitError("kernel_shape length incompatible with spatial dimensions");
|
emitError("kernel_shape length incompatible with spatial dimensions");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
// Have the right number of values, check them.
|
// Have the right number of values, check them.
|
||||||
for (int i = 0; i < spatialRank; ++i)
|
for (int i = 0; i < spatialRank; ++i)
|
||||||
if (ArrayAttrIntVal(kernelShape, i) < 1)
|
if (ArrayAttrIntVal(kernelShape, i) < 1) {
|
||||||
emitError("bad kernel_shape value");
|
emitError("bad kernel_shape value");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// Deduce shape from weight input.
|
// Deduce shape from weight input.
|
||||||
SmallVector<int64_t, 2> defaultVals;
|
SmallVector<int64_t, 2> defaultVals;
|
||||||
|
@ -1119,6 +1197,7 @@ void ONNXConvOp::inferShapes() {
|
||||||
&outputDims, xShape, kernelShape, padsOpt, stridesOpt, dilationsOpt);
|
&outputDims, xShape, kernelShape, padsOpt, stridesOpt, dilationsOpt);
|
||||||
|
|
||||||
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
|
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1129,10 +1208,10 @@ void ONNXConvOp::inferShapes() {
|
||||||
// - strides: set to 1 if not defined by user;
|
// - strides: set to 1 if not defined by user;
|
||||||
// - pads: set to proper value, 0 if not defined by user.
|
// - pads: set to proper value, 0 if not defined by user.
|
||||||
|
|
||||||
void ONNXAveragePoolOp::inferShapes() {
|
bool ONNXAveragePoolOp::inferShapes() {
|
||||||
// Cannot infer shape if no shape exists.
|
// Cannot infer shape if no shape exists.
|
||||||
if (!X().getType().isa<RankedTensorType>())
|
if (!X().getType().isa<RankedTensorType>())
|
||||||
return;
|
return false;
|
||||||
|
|
||||||
// Get shape of input.
|
// Get shape of input.
|
||||||
auto xTy = X().getType().cast<RankedTensorType>();
|
auto xTy = X().getType().cast<RankedTensorType>();
|
||||||
|
@ -1163,6 +1242,7 @@ void ONNXAveragePoolOp::inferShapes() {
|
||||||
llvm::None, ceilMode);
|
llvm::None, ceilMode);
|
||||||
|
|
||||||
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
|
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1173,10 +1253,10 @@ void ONNXAveragePoolOp::inferShapes() {
|
||||||
// - dilations, strides: set to 1 if not defined by user;
|
// - dilations, strides: set to 1 if not defined by user;
|
||||||
// - pads: set to proper value, 0 if not defined by user.
|
// - pads: set to proper value, 0 if not defined by user.
|
||||||
|
|
||||||
void ONNXMaxPoolSingleOutOp::inferShapes() {
|
bool ONNXMaxPoolSingleOutOp::inferShapes() {
|
||||||
// Cannot infer shape if no shape exists.
|
// Cannot infer shape if no shape exists.
|
||||||
if (!X().getType().isa<RankedTensorType>())
|
if (!X().getType().isa<RankedTensorType>())
|
||||||
return;
|
return false;
|
||||||
|
|
||||||
// Get shape of input.
|
// Get shape of input.
|
||||||
auto xTy = X().getType().cast<RankedTensorType>();
|
auto xTy = X().getType().cast<RankedTensorType>();
|
||||||
|
@ -1211,6 +1291,7 @@ void ONNXMaxPoolSingleOutOp::inferShapes() {
|
||||||
dilationsOpt, ceilMode);
|
dilationsOpt, ceilMode);
|
||||||
|
|
||||||
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
|
getResult().setType(RankedTensorType::get(outputDims, xTy.getElementType()));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1246,24 +1327,26 @@ static Type padShapeInferenceHelper(Value data, ArrayAttr padsOpt) {
|
||||||
|
|
||||||
// PadConstantPad
|
// PadConstantPad
|
||||||
|
|
||||||
void ONNXPadConstantPadOp::inferShapes() {
|
bool ONNXPadConstantPadOp::inferShapes() {
|
||||||
auto outputType = padShapeInferenceHelper(data(), pads());
|
auto outputType = padShapeInferenceHelper(data(), pads());
|
||||||
if (outputType) {
|
if (outputType) {
|
||||||
getResult().setType(outputType);
|
getResult().setType(outputType);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// PadConstantValuePad
|
// PadConstantValuePad
|
||||||
|
|
||||||
void ONNXPadConstantValuePadOp::inferShapes() {
|
bool ONNXPadConstantValuePadOp::inferShapes() {
|
||||||
auto outputType = padShapeInferenceHelper(data(), pads());
|
auto outputType = padShapeInferenceHelper(data(), pads());
|
||||||
if (outputType) {
|
if (outputType) {
|
||||||
getResult().setType(outputType);
|
getResult().setType(outputType);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
return;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
|
void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
|
||||||
|
@ -1280,9 +1363,9 @@ void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
|
||||||
|
|
||||||
// Unsqueeze
|
// Unsqueeze
|
||||||
|
|
||||||
void ONNXUnsqueezeOp::inferShapes() {
|
bool ONNXUnsqueezeOp::inferShapes() {
|
||||||
if (!data().getType().isa<RankedTensorType>())
|
if (!data().getType().isa<RankedTensorType>())
|
||||||
return;
|
return false;
|
||||||
|
|
||||||
auto operandTy = data().getType().cast<RankedTensorType>();
|
auto operandTy = data().getType().cast<RankedTensorType>();
|
||||||
int inRank = operandTy.getRank();
|
int inRank = operandTy.getRank();
|
||||||
|
@ -1299,11 +1382,14 @@ void ONNXUnsqueezeOp::inferShapes() {
|
||||||
assert(axis >= -outRank && axis <= outRank - 1);
|
assert(axis >= -outRank && axis <= outRank - 1);
|
||||||
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
|
if (std::find(axes.begin(), axes.end(), axis) == axes.end())
|
||||||
axes.emplace_back(axis);
|
axes.emplace_back(axis);
|
||||||
else
|
else {
|
||||||
emitError("Duplicated axes");
|
emitError("Duplicated axes");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
emitError("Axes attribute is required");
|
emitError("Axes attribute is required");
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
SmallVector<int64_t, 4> dims;
|
SmallVector<int64_t, 4> dims;
|
||||||
|
@ -1315,12 +1401,13 @@ void ONNXUnsqueezeOp::inferShapes() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
|
getResult().setType(RankedTensorType::get(dims, operandTy.getElementType()));
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Constant
|
// Constant
|
||||||
|
|
||||||
void ONNXConstantOp::inferShapes() {
|
bool ONNXConstantOp::inferShapes() {
|
||||||
if ((sparse_value().hasValue() && value().hasValue()) ||
|
if ((sparse_value().hasValue() && value().hasValue()) ||
|
||||||
(!sparse_value().hasValue() && !value().hasValue()))
|
(!sparse_value().hasValue() && !value().hasValue()))
|
||||||
emitError("Require exactly one of the two attributes, either value or "
|
emitError("Require exactly one of the two attributes, either value or "
|
||||||
|
@ -1332,6 +1419,7 @@ void ONNXConstantOp::inferShapes() {
|
||||||
else
|
else
|
||||||
valAttr = valueAttr().cast<DenseElementsAttr>();
|
valAttr = valueAttr().cast<DenseElementsAttr>();
|
||||||
getResult().setType(valAttr.getType());
|
getResult().setType(valAttr.getType());
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -25,7 +25,7 @@ def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
|
||||||
|
|
||||||
let methods = [
|
let methods = [
|
||||||
InterfaceMethod<"Infer and set the output shape for the current operation.",
|
InterfaceMethod<"Infer and set the output shape for the current operation.",
|
||||||
"void", "inferShapes">
|
"bool", "inferShapes">
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,11 @@ public:
|
||||||
f.walk([&](mlir::Operation *op) {
|
f.walk([&](mlir::Operation *op) {
|
||||||
if (returnsDynamicShape(op)) {
|
if (returnsDynamicShape(op)) {
|
||||||
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
||||||
shape_op.inferShapes();
|
if (!shape_op.inferShapes()) {
|
||||||
|
op->emitError("unable to infer shape of operation without shape "
|
||||||
|
"inference interface");
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
op->emitError("unable to infer shape of operation without shape "
|
op->emitError("unable to infer shape of operation without shape "
|
||||||
"inference interface");
|
"inference interface");
|
||||||
|
|
Loading…
Reference in New Issue