diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp index 02990c4..cb79974 100644 --- a/src/Builder/FrontendDialectTransformer.cpp +++ b/src/Builder/FrontendDialectTransformer.cpp @@ -434,10 +434,17 @@ private: module_.push_back(entryPoint); // Map graph inputs to entry block arguments. - for (int i = 0; i < graph.input().size(); ++i) + // Counter of un-initialized tensors. This counter is used to index the + // entry block arguments. + int entryBlockArgIdx = 0; + for (int i = 0; i < graph.input().size(); ++i) { if (!initializedTensors.ContainKey( - legalize_name(graph.input()[i].name()))) - ImportInputTensorSymbol(graph.input()[i], entryBlock.getArguments()[i]); + legalize_name(graph.input()[i].name()))) { + ImportInputTensorSymbol( + graph.input()[i], entryBlock.getArguments()[entryBlockArgIdx]); + entryBlockArgIdx++; + } + } // Create a NoneTyped constant to be used for optional operation inputs // which are not used. diff --git a/src/Dialect/ONNX/ONNXOps.cpp b/src/Dialect/ONNX/ONNXOps.cpp index 562d736..c52b6a8 100644 --- a/src/Dialect/ONNX/ONNXOps.cpp +++ b/src/Dialect/ONNX/ONNXOps.cpp @@ -514,7 +514,7 @@ bool ONNXAbsOp::inferShapes() { bool ONNXAddOp::inferShapes() { if (!getOperand(0).getType().isa() || !getOperand(1).getType().isa()) { - emitError("ONNXAddOp inferShapes failed"); + emitError("Input tensor(s) not ranked"); return false; } auto lhsTy = getOperand(0).getType().cast(); @@ -529,8 +529,10 @@ bool ONNXAddOp::inferShapes() { /// shape inference interface. bool ONNXMulOp::inferShapes() { if (!getOperand(0).getType().isa() || - !getOperand(1).getType().isa()) + !getOperand(1).getType().isa()) { + emitError("Input tensor(s) not ranked"); return false; + } auto lhsTy = getOperand(0).getType().cast(); auto rhsTy = getOperand(1).getType().cast(); getResult().setType(getBroadcastedType(lhsTy, rhsTy)); @@ -543,8 +545,10 @@ bool ONNXMulOp::inferShapes() { /// shape inference interface. bool ONNXDivOp::inferShapes() { if (!getOperand(0).getType().isa() || - !getOperand(1).getType().isa()) + !getOperand(1).getType().isa()) { + emitError("Input tensor(s) not ranked"); return false; + } auto lhsTy = getOperand(0).getType().cast(); auto rhsTy = getOperand(1).getType().cast(); getResult().setType(getBroadcastedType(lhsTy, rhsTy)); @@ -557,8 +561,10 @@ bool ONNXDivOp::inferShapes() { /// shape inference interface. bool ONNXSubOp::inferShapes() { if (!getOperand(0).getType().isa() || - !getOperand(1).getType().isa()) + !getOperand(1).getType().isa()) { + emitError("Input tensor(s) not ranked"); return false; + } auto lhsTy = getOperand(0).getType().cast(); auto rhsTy = getOperand(1).getType().cast(); getResult().setType(getBroadcastedType(lhsTy, rhsTy)); @@ -571,8 +577,10 @@ bool ONNXSubOp::inferShapes() { /// shape inference interface. bool ONNXAndOp::inferShapes() { if (!getOperand(0).getType().isa() || - !getOperand(1).getType().isa()) + !getOperand(1).getType().isa()) { + emitError("Input tensor(s) not ranked"); return false; + } auto lhsTy = getOperand(0).getType().cast(); auto rhsTy = getOperand(1).getType().cast(); getResult().setType(getBroadcastedType(lhsTy, rhsTy)); @@ -585,8 +593,10 @@ bool ONNXAndOp::inferShapes() { /// shape inference interface. bool ONNXOrOp::inferShapes() { if (!getOperand(0).getType().isa() || - !getOperand(1).getType().isa()) + !getOperand(1).getType().isa()) { + emitError("Input tensor(s) not ranked"); return false; + } auto lhsTy = getOperand(0).getType().cast(); auto rhsTy = getOperand(1).getType().cast(); getResult().setType(getBroadcastedType(lhsTy, rhsTy)); @@ -599,8 +609,10 @@ bool ONNXOrOp::inferShapes() { /// shape inference interface. bool ONNXXorOp::inferShapes() { if (!getOperand(0).getType().isa() || - !getOperand(1).getType().isa()) + !getOperand(1).getType().isa()) { + emitError("Input tensor(s) not ranked"); return false; + } auto lhsTy = getOperand(0).getType().cast(); auto rhsTy = getOperand(1).getType().cast(); getResult().setType(getBroadcastedType(lhsTy, rhsTy)); @@ -615,8 +627,10 @@ bool ONNXXorOp::inferShapes() { /// shape inference interface. bool ONNXSumOp::inferShapes() { for (int i = 0; i < getNumOperands(); ++i) { - if (!getOperand(i).getType().cast()) + if (!getOperand(i).getType().cast()) { + emitError("Input tensor(s) not ranked"); return false; + } } Type resultTy = getOperand(0).getType().cast(); for (int i = 1; i < getNumOperands(); ++i) { @@ -633,8 +647,10 @@ bool ONNXSumOp::inferShapes() { /// shape inference interface. bool ONNXMaxOp::inferShapes() { for (int i = 0; i < getNumOperands(); ++i) { - if (!getOperand(i).getType().cast()) + if (!getOperand(i).getType().cast()) { + emitError("Input tensor(s) not ranked"); return false; + } } Type resultTy = getOperand(0).getType().cast(); for (int i = 1; i < getNumOperands(); ++i) { @@ -651,8 +667,10 @@ bool ONNXMaxOp::inferShapes() { /// shape inference interface. bool ONNXMinOp::inferShapes() { for (int i = 0; i < getNumOperands(); ++i) { - if (!getOperand(i).getType().cast()) + if (!getOperand(i).getType().cast()) { + emitError("Input tensor(s) not ranked"); return false; + } } Type resultTy = getOperand(0).getType().cast(); for (int i = 1; i < getNumOperands(); ++i) { @@ -679,8 +697,10 @@ bool ONNXIdentityOp::inferShapes() { bool ONNXMatMulOp::inferShapes() { // Cannot infer shape if no shape exists. if (!A().getType().isa() || - !B().getType().isa()) + !B().getType().isa()) { + emitError("Input tensor(s) not ranked"); return false; + } auto lhsTy = A().getType().cast(); auto rhsTy = B().getType().cast(); @@ -819,8 +839,10 @@ bool ONNXGemmOp::inferShapes() { // Cannot infer shape if no shape exists. if (!A().getType().isa() || !B().getType().isa() || - (hasBias && !C().getType().isa())) + (hasBias && !C().getType().isa())) { + emitError("Input tensor(s) not ranked"); return false; + } auto lhsTy = A().getType().cast(); auto rhsTy = B().getType().cast(); @@ -862,8 +884,10 @@ bool ONNXBatchNormalizationTestModeOp::inferShapes() { !scale().getType().isa() || !B().getType().isa() || !mean().getType().isa() || - !var().getType().isa()) + !var().getType().isa()) { + emitError("Input tensor(s) not ranked"); return false; + } auto inputTensorTy = X().getType().cast(); auto scaleTensorTy = scale().getType().cast(); @@ -915,8 +939,15 @@ bool ONNXBatchNormalizationTestModeOp::inferShapes() { bool ONNXReshapeOp::inferShapes() { // Cannot infer shape if no shape tensor is specified. - if (!shape().getType().isa()) + if (!data().getType().isa()) { + emitError("Input data tensor not ranked"); + return false; + } + + if (!shape().getType().isa()) { emitError("Shape tensor not ranked"); + return false; + } auto inputTensorTy = data().getType().cast(); auto shapeTensorTy = shape().getType().cast(); @@ -991,8 +1022,10 @@ bool ONNXReshapeOp::inferShapes() { bool ONNXTransposeOp::inferShapes() { // Cannot infer shape if no shape exists. - if (!data().getType().isa()) + if (!data().getType().isa()) { + emitError("Input tensor not ranked"); return false; + } // Naive transposition which handles the default case of // reversing the shape of the tensor (similar to numpy.transpose). @@ -1019,7 +1052,7 @@ bool ONNXTransposeOp::inferShapes() { bool ONNXReduceMaxOp::inferShapes() { if (!getOperand().getType().isa()) { - emitError("Shape tensor not ranked"); + emitError("Input tensor not ranked"); return false; } @@ -1034,7 +1067,7 @@ bool ONNXReduceMaxOp::inferShapes() { bool ONNXReduceMinOp::inferShapes() { if (!getOperand().getType().isa()) { - emitError("Shape tensor not ranked"); + emitError("Input tensor not ranked"); return false; } @@ -1049,7 +1082,7 @@ bool ONNXReduceMinOp::inferShapes() { bool ONNXReduceProdOp::inferShapes() { if (!getOperand().getType().isa()) { - emitError("Shape tensor not ranked"); + emitError("Input tensor not ranked"); return false; } @@ -1064,7 +1097,7 @@ bool ONNXReduceProdOp::inferShapes() { bool ONNXReduceSumOp::inferShapes() { if (!getOperand().getType().isa()) { - emitError("Shape tensor not ranked"); + emitError("Input tensor not ranked"); return false; } @@ -1097,8 +1130,10 @@ bool ONNXConvOp::inferShapes() { // Cannot infer shape if no shape exists. if (!X().getType().isa() || !W().getType().isa() || - (hasBias && !B().getType().isa())) + (hasBias && !B().getType().isa())) { + emitError("Input tensor not ranked"); return false; + } auto xTy = X().getType().cast(); auto xShape = xTy.getShape(); @@ -1210,8 +1245,10 @@ bool ONNXConvOp::inferShapes() { bool ONNXAveragePoolOp::inferShapes() { // Cannot infer shape if no shape exists. - if (!X().getType().isa()) + if (!X().getType().isa()) { + emitError("Input tensor not ranked"); return false; + } // Get shape of input. auto xTy = X().getType().cast(); @@ -1255,8 +1292,10 @@ bool ONNXAveragePoolOp::inferShapes() { bool ONNXMaxPoolSingleOutOp::inferShapes() { // Cannot infer shape if no shape exists. - if (!X().getType().isa()) + if (!X().getType().isa()) { + emitError("Input tensor not ranked"); return false; + } // Get shape of input. auto xTy = X().getType().cast(); @@ -1364,8 +1403,10 @@ void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state, // Unsqueeze bool ONNXUnsqueezeOp::inferShapes() { - if (!data().getType().isa()) + if (!data().getType().isa()) { + emitError("Input tensor not ranked"); return false; + } auto operandTy = data().getType().cast(); int inRank = operandTy.getRank(); diff --git a/src/MainUtils.hpp b/src/MainUtils.hpp index eef6521..143f6fc 100644 --- a/src/MainUtils.hpp +++ b/src/MainUtils.hpp @@ -38,6 +38,7 @@ #include "mlir/Transforms/Passes.h" enum EmissionTargetType { + EmitONNXBasic, EmitONNXIR, EmitMLIR, EmitLLVMIR, diff --git a/src/Transform/ONNX/ShapeInferencePass.cpp b/src/Transform/ONNX/ShapeInferencePass.cpp index 1bf9855..b4385c2 100644 --- a/src/Transform/ONNX/ShapeInferencePass.cpp +++ b/src/Transform/ONNX/ShapeInferencePass.cpp @@ -37,7 +37,7 @@ public: if (auto shape_op = dyn_cast(op)) { if (!shape_op.inferShapes()) { op->emitError("unable to infer shape of operation without shape " - "inference interface"); + "inference method"); return signalPassFailure(); } } else { diff --git a/src/main.cpp b/src/main.cpp index 1dacbff..8bbdef7 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -23,6 +23,9 @@ int main(int argc, char *argv[]) { llvm::cl::opt emissionTarget( llvm::cl::desc("Choose target to emit:"), llvm::cl::values( + clEnumVal(EmitONNXBasic, + "Ingest ONNX and emit the basic ONNX operations without" + "inferred shapes."), clEnumVal(EmitONNXIR, "Ingest ONNX and emit corresponding ONNX dialect."), clEnumVal(EmitMLIR, @@ -41,7 +44,8 @@ int main(int argc, char *argv[]) { processInputFile(inputFilename, emissionTarget, context, module); mlir::PassManager pm(&context); - addONNXToMLIRPasses(pm); + if (emissionTarget >= EmitONNXIR) + addONNXToMLIRPasses(pm); if (emissionTarget >= EmitMLIR) { addONNXToKrnlPasses(pm);