Fix input argument indexing error (#69)
* Reorganize main function.
* Follow review comments.
* Emit constants as globals in Krnl and LLVM dialects.
* Fixes. Emit error before return in shape inference.
* Fix description.
* Fix emitted error message.
* Fix index name.
parent 83eb15bfae
commit 8532a10614
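The indexing error in brief: a graph input that carries an initializer is not materialized as an entry block argument, so indexing the block arguments with the graph-input index i drifts out of sync (or out of bounds) once an initializer is skipped. Below is a minimal standalone sketch of the corrected counting scheme; the container types and names are hypothetical stand-ins for the importer's data structures, not the project's actual API.

#include <iostream>
#include <set>
#include <string>
#include <vector>

int main() {
  // Hypothetical stand-ins: four graph inputs, two of which carry
  // initializers and therefore get no entry block argument.
  std::vector<std::string> graphInputs = {"a", "w", "b", "x"};
  std::set<std::string> initializedTensors = {"w", "b"};
  std::vector<std::string> entryBlockArgs = {"%arg0", "%arg1"};

  // Buggy scheme: entryBlockArgs[i] would read index 3 for input "x",
  // which does not exist. Fixed scheme: a separate counter advances
  // only when an input actually maps to a block argument.
  int entryBlockArgIdx = 0;
  for (size_t i = 0; i < graphInputs.size(); ++i) {
    if (initializedTensors.count(graphInputs[i]) == 0) {
      std::cout << graphInputs[i] << " -> "
                << entryBlockArgs[entryBlockArgIdx] << "\n";
      entryBlockArgIdx++;
    }
  }
  // Prints: a -> %arg0, then x -> %arg1.
}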
@@ -434,10 +434,17 @@ private:
     module_.push_back(entryPoint);
 
     // Map graph inputs to entry block arguments.
-    for (int i = 0; i < graph.input().size(); ++i)
+    // Counter of un-initialized tensors. This counter is used to index the
+    // entry block arguments.
+    int entryBlockArgIdx = 0;
+    for (int i = 0; i < graph.input().size(); ++i) {
       if (!initializedTensors.ContainKey(
-              legalize_name(graph.input()[i].name())))
-        ImportInputTensorSymbol(graph.input()[i], entryBlock.getArguments()[i]);
+              legalize_name(graph.input()[i].name()))) {
+        ImportInputTensorSymbol(
+            graph.input()[i], entryBlock.getArguments()[entryBlockArgIdx]);
+        entryBlockArgIdx++;
+      }
+    }
 
     // Create a NoneTyped constant to be used for optional operation inputs
     // which are not used.
@@ -514,7 +514,7 @@ bool ONNXAbsOp::inferShapes() {
 bool ONNXAddOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
       !getOperand(1).getType().isa<RankedTensorType>()) {
-    emitError("ONNXAddOp inferShapes failed");
+    emitError("Input tensor(s) not ranked");
     return false;
   }
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
@@ -529,8 +529,10 @@ bool ONNXAddOp::inferShapes() {
 /// shape inference interface.
 bool ONNXMulOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
+      !getOperand(1).getType().isa<RankedTensorType>()) {
+    emitError("Input tensor(s) not ranked");
     return false;
+  }
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@@ -543,8 +545,10 @@ bool ONNXMulOp::inferShapes() {
 /// shape inference interface.
 bool ONNXDivOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
+      !getOperand(1).getType().isa<RankedTensorType>()) {
+    emitError("Input tensor(s) not ranked");
     return false;
+  }
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@@ -557,8 +561,10 @@ bool ONNXDivOp::inferShapes() {
 /// shape inference interface.
 bool ONNXSubOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
+      !getOperand(1).getType().isa<RankedTensorType>()) {
+    emitError("Input tensor(s) not ranked");
     return false;
+  }
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@@ -571,8 +577,10 @@ bool ONNXSubOp::inferShapes() {
 /// shape inference interface.
 bool ONNXAndOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
+      !getOperand(1).getType().isa<RankedTensorType>()) {
+    emitError("Input tensor(s) not ranked");
     return false;
+  }
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@@ -585,8 +593,10 @@ bool ONNXAndOp::inferShapes() {
 /// shape inference interface.
 bool ONNXOrOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
+      !getOperand(1).getType().isa<RankedTensorType>()) {
+    emitError("Input tensor(s) not ranked");
     return false;
+  }
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@@ -599,8 +609,10 @@ bool ONNXOrOp::inferShapes() {
 /// shape inference interface.
 bool ONNXXorOp::inferShapes() {
   if (!getOperand(0).getType().isa<RankedTensorType>() ||
-      !getOperand(1).getType().isa<RankedTensorType>())
+      !getOperand(1).getType().isa<RankedTensorType>()) {
+    emitError("Input tensor(s) not ranked");
     return false;
+  }
   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
   getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@@ -615,9 +627,11 @@ bool ONNXXorOp::inferShapes() {
 /// shape inference interface.
 bool ONNXSumOp::inferShapes() {
   for (int i = 0; i < getNumOperands(); ++i) {
-    if (!getOperand(i).getType().cast<RankedTensorType>())
+    if (!getOperand(i).getType().cast<RankedTensorType>()) {
+      emitError("Input tensor(s) not ranked");
       return false;
     }
+  }
   Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
   for (int i = 1; i < getNumOperands(); ++i) {
     Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
@@ -633,9 +647,11 @@ bool ONNXSumOp::inferShapes() {
 /// shape inference interface.
 bool ONNXMaxOp::inferShapes() {
   for (int i = 0; i < getNumOperands(); ++i) {
-    if (!getOperand(i).getType().cast<RankedTensorType>())
+    if (!getOperand(i).getType().cast<RankedTensorType>()) {
+      emitError("Input tensor(s) not ranked");
       return false;
     }
+  }
   Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
   for (int i = 1; i < getNumOperands(); ++i) {
     Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
@@ -651,9 +667,11 @@ bool ONNXMaxOp::inferShapes() {
 /// shape inference interface.
 bool ONNXMinOp::inferShapes() {
   for (int i = 0; i < getNumOperands(); ++i) {
-    if (!getOperand(i).getType().cast<RankedTensorType>())
+    if (!getOperand(i).getType().cast<RankedTensorType>()) {
+      emitError("Input tensor(s) not ranked");
       return false;
     }
+  }
   Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
   for (int i = 1; i < getNumOperands(); ++i) {
     Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
@@ -679,8 +697,10 @@ bool ONNXIdentityOp::inferShapes() {
 bool ONNXMatMulOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!A().getType().isa<RankedTensorType>() ||
-      !B().getType().isa<RankedTensorType>())
+      !B().getType().isa<RankedTensorType>()) {
+    emitError("Input tensor(s) not ranked");
     return false;
+  }
 
   auto lhsTy = A().getType().cast<RankedTensorType>();
   auto rhsTy = B().getType().cast<RankedTensorType>();
@@ -819,8 +839,10 @@ bool ONNXGemmOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!A().getType().isa<RankedTensorType>() ||
       !B().getType().isa<RankedTensorType>() ||
-      (hasBias && !C().getType().isa<RankedTensorType>()))
+      (hasBias && !C().getType().isa<RankedTensorType>())) {
+    emitError("Input tensor(s) not ranked");
     return false;
+  }
   auto lhsTy = A().getType().cast<RankedTensorType>();
   auto rhsTy = B().getType().cast<RankedTensorType>();
 
@@ -862,8 +884,10 @@ bool ONNXBatchNormalizationTestModeOp::inferShapes() {
       !scale().getType().isa<RankedTensorType>() ||
       !B().getType().isa<RankedTensorType>() ||
       !mean().getType().isa<RankedTensorType>() ||
-      !var().getType().isa<RankedTensorType>())
+      !var().getType().isa<RankedTensorType>()) {
+    emitError("Input tensor(s) not ranked");
     return false;
+  }
 
   auto inputTensorTy = X().getType().cast<RankedTensorType>();
   auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
@@ -915,8 +939,15 @@ bool ONNXBatchNormalizationTestModeOp::inferShapes() {
 
 bool ONNXReshapeOp::inferShapes() {
   // Cannot infer shape if no shape tensor is specified.
-  if (!shape().getType().isa<RankedTensorType>())
+  if (!data().getType().isa<RankedTensorType>()) {
+    emitError("Input data tensor not ranked");
+    return false;
+  }
+
+  if (!shape().getType().isa<RankedTensorType>()) {
     emitError("Shape tensor not ranked");
+    return false;
+  }
 
   auto inputTensorTy = data().getType().cast<RankedTensorType>();
   auto shapeTensorTy = shape().getType().cast<RankedTensorType>();
@@ -991,8 +1022,10 @@ bool ONNXReshapeOp::inferShapes() {
 
 bool ONNXTransposeOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!data().getType().isa<RankedTensorType>())
+  if (!data().getType().isa<RankedTensorType>()) {
+    emitError("Input tensor not ranked");
     return false;
+  }
 
   // Naive transposition which handles the default case of
   // reversing the shape of the tensor (similar to numpy.transpose).
@@ -1019,7 +1052,7 @@ bool ONNXTransposeOp::inferShapes() {
 
 bool ONNXReduceMaxOp::inferShapes() {
   if (!getOperand().getType().isa<RankedTensorType>()) {
-    emitError("Shape tensor not ranked");
+    emitError("Input tensor not ranked");
     return false;
   }
 
@@ -1034,7 +1067,7 @@ bool ONNXReduceMaxOp::inferShapes() {
 
 bool ONNXReduceMinOp::inferShapes() {
   if (!getOperand().getType().isa<RankedTensorType>()) {
-    emitError("Shape tensor not ranked");
+    emitError("Input tensor not ranked");
     return false;
   }
 
@@ -1049,7 +1082,7 @@ bool ONNXReduceMinOp::inferShapes() {
 
 bool ONNXReduceProdOp::inferShapes() {
   if (!getOperand().getType().isa<RankedTensorType>()) {
-    emitError("Shape tensor not ranked");
+    emitError("Input tensor not ranked");
     return false;
   }
 
@@ -1064,7 +1097,7 @@ bool ONNXReduceProdOp::inferShapes() {
 
 bool ONNXReduceSumOp::inferShapes() {
   if (!getOperand().getType().isa<RankedTensorType>()) {
-    emitError("Shape tensor not ranked");
+    emitError("Input tensor not ranked");
     return false;
   }
 
@@ -1097,8 +1130,10 @@ bool ONNXConvOp::inferShapes() {
   // Cannot infer shape if no shape exists.
   if (!X().getType().isa<RankedTensorType>() ||
       !W().getType().isa<RankedTensorType>() ||
-      (hasBias && !B().getType().isa<RankedTensorType>()))
+      (hasBias && !B().getType().isa<RankedTensorType>())) {
+    emitError("Input tensor not ranked");
     return false;
+  }
 
   auto xTy = X().getType().cast<RankedTensorType>();
   auto xShape = xTy.getShape();
@@ -1210,8 +1245,10 @@ bool ONNXConvOp::inferShapes() {
 
 bool ONNXAveragePoolOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!X().getType().isa<RankedTensorType>())
+  if (!X().getType().isa<RankedTensorType>()) {
+    emitError("Input tensor not ranked");
     return false;
+  }
 
   // Get shape of input.
   auto xTy = X().getType().cast<RankedTensorType>();
@@ -1255,8 +1292,10 @@ bool ONNXAveragePoolOp::inferShapes() {
 
 bool ONNXMaxPoolSingleOutOp::inferShapes() {
   // Cannot infer shape if no shape exists.
-  if (!X().getType().isa<RankedTensorType>())
+  if (!X().getType().isa<RankedTensorType>()) {
+    emitError("Input tensor not ranked");
     return false;
+  }
 
   // Get shape of input.
   auto xTy = X().getType().cast<RankedTensorType>();
@@ -1364,8 +1403,10 @@ void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
 // Unsqueeze
 
 bool ONNXUnsqueezeOp::inferShapes() {
-  if (!data().getType().isa<RankedTensorType>())
+  if (!data().getType().isa<RankedTensorType>()) {
+    emitError("Input tensor not ranked");
     return false;
+  }
 
   auto operandTy = data().getType().cast<RankedTensorType>();
   int inRank = operandTy.getRank();
@@ -38,6 +38,7 @@
 #include "mlir/Transforms/Passes.h"
 
 enum EmissionTargetType {
+  EmitONNXBasic,
   EmitONNXIR,
   EmitMLIR,
   EmitLLVMIR,
@@ -37,7 +37,7 @@ public:
       if (auto shape_op = dyn_cast<ShapeInference>(op)) {
         if (!shape_op.inferShapes()) {
           op->emitError("unable to infer shape of operation without shape "
-                        "inference interface");
+                        "inference method");
           return signalPassFailure();
         }
       } else {
@@ -23,6 +23,9 @@ int main(int argc, char *argv[]) {
   llvm::cl::opt<EmissionTargetType> emissionTarget(
       llvm::cl::desc("Choose target to emit:"),
       llvm::cl::values(
+          clEnumVal(EmitONNXBasic,
+              "Ingest ONNX and emit the basic ONNX operations without"
+              "inferred shapes."),
           clEnumVal(EmitONNXIR,
               "Ingest ONNX and emit corresponding ONNX dialect."),
           clEnumVal(EmitMLIR,
@@ -41,6 +44,7 @@ int main(int argc, char *argv[]) {
   processInputFile(inputFilename, emissionTarget, context, module);
 
   mlir::PassManager pm(&context);
+  if (emissionTarget >= EmitONNXIR)
     addONNXToMLIRPasses(pm);
 
   if (emissionTarget >= EmitMLIR) {
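For context, the driver gating above works because EmissionTargetType is declared in ascending order (EmitONNXBasic < EmitONNXIR < EmitMLIR < EmitLLVMIR), so each target enables the pass groups of every earlier stage. A reduced sketch of that ordering logic; the second pass-group name is a hypothetical placeholder, not the project's actual function.

#include <iostream>

enum EmissionTargetType { EmitONNXBasic, EmitONNXIR, EmitMLIR, EmitLLVMIR };

int main() {
  // Selecting the new EmitONNXBasic target skips the ONNX-to-MLIR passes
  // (including shape inference), so models whose shapes cannot be inferred
  // can still be ingested and inspected.
  EmissionTargetType emissionTarget = EmitONNXBasic;
  if (emissionTarget >= EmitONNXIR)
    std::cout << "run addONNXToMLIRPasses\n"; // not reached for EmitONNXBasic
  if (emissionTarget >= EmitMLIR)
    std::cout << "run lowering passes\n";     // hypothetical later stage
  return 0;
}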