Fix input argument indexing error (#69)
* Reorganize main function. * Follow review comments. * Emit constants are globals in Krnl and LLVM dialects. * Fixes. Emit error before return in shape inference. * Fix description. * Fix emitted error message. * Fix index name.
This commit is contained in:
		
							parent
							
								
									83eb15bfae
								
							
						
					
					
						commit
						8532a10614
					
				|  | @ -434,10 +434,17 @@ private: | |||
|     module_.push_back(entryPoint); | ||||
| 
 | ||||
|     // Map graph inputs to entry block arguments.
 | ||||
|     for (int i = 0; i < graph.input().size(); ++i) | ||||
|     // Counter of un-initialized tensors. This counter is used to index the
 | ||||
|     // entry block arguments.
 | ||||
|     int entryBlockArgIdx = 0; | ||||
|     for (int i = 0; i < graph.input().size(); ++i) { | ||||
|       if (!initializedTensors.ContainKey( | ||||
|               legalize_name(graph.input()[i].name()))) | ||||
|         ImportInputTensorSymbol(graph.input()[i], entryBlock.getArguments()[i]); | ||||
|               legalize_name(graph.input()[i].name()))) { | ||||
|         ImportInputTensorSymbol( | ||||
|             graph.input()[i], entryBlock.getArguments()[entryBlockArgIdx]); | ||||
|         entryBlockArgIdx++; | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     // Create a NoneTyped constant to be used for optional operation inputs
 | ||||
|     // which are not used.
 | ||||
|  |  | |||
|  | @ -514,7 +514,7 @@ bool ONNXAbsOp::inferShapes() { | |||
| bool ONNXAddOp::inferShapes() { | ||||
|   if (!getOperand(0).getType().isa<RankedTensorType>() || | ||||
|       !getOperand(1).getType().isa<RankedTensorType>()) { | ||||
|     emitError("ONNXAddOp inferShapes failed"); | ||||
|     emitError("Input tensor(s) not ranked"); | ||||
|     return false; | ||||
|   } | ||||
|   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); | ||||
|  | @ -529,8 +529,10 @@ bool ONNXAddOp::inferShapes() { | |||
| /// shape inference interface.
 | ||||
| bool ONNXMulOp::inferShapes() { | ||||
|   if (!getOperand(0).getType().isa<RankedTensorType>() || | ||||
|       !getOperand(1).getType().isa<RankedTensorType>()) | ||||
|       !getOperand(1).getType().isa<RankedTensorType>()) { | ||||
|     emitError("Input tensor(s) not ranked"); | ||||
|     return false; | ||||
|   } | ||||
|   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); | ||||
|   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); | ||||
|   getResult().setType(getBroadcastedType(lhsTy, rhsTy)); | ||||
|  | @ -543,8 +545,10 @@ bool ONNXMulOp::inferShapes() { | |||
| /// shape inference interface.
 | ||||
| bool ONNXDivOp::inferShapes() { | ||||
|   if (!getOperand(0).getType().isa<RankedTensorType>() || | ||||
|       !getOperand(1).getType().isa<RankedTensorType>()) | ||||
|       !getOperand(1).getType().isa<RankedTensorType>()) { | ||||
|     emitError("Input tensor(s) not ranked"); | ||||
|     return false; | ||||
|   } | ||||
|   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); | ||||
|   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); | ||||
|   getResult().setType(getBroadcastedType(lhsTy, rhsTy)); | ||||
|  | @ -557,8 +561,10 @@ bool ONNXDivOp::inferShapes() { | |||
| /// shape inference interface.
 | ||||
| bool ONNXSubOp::inferShapes() { | ||||
|   if (!getOperand(0).getType().isa<RankedTensorType>() || | ||||
|       !getOperand(1).getType().isa<RankedTensorType>()) | ||||
|       !getOperand(1).getType().isa<RankedTensorType>()) { | ||||
|     emitError("Input tensor(s) not ranked"); | ||||
|     return false; | ||||
|   } | ||||
|   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); | ||||
|   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); | ||||
|   getResult().setType(getBroadcastedType(lhsTy, rhsTy)); | ||||
|  | @ -571,8 +577,10 @@ bool ONNXSubOp::inferShapes() { | |||
| /// shape inference interface.
 | ||||
| bool ONNXAndOp::inferShapes() { | ||||
|   if (!getOperand(0).getType().isa<RankedTensorType>() || | ||||
|       !getOperand(1).getType().isa<RankedTensorType>()) | ||||
|       !getOperand(1).getType().isa<RankedTensorType>()) { | ||||
|     emitError("Input tensor(s) not ranked"); | ||||
|     return false; | ||||
|   } | ||||
|   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); | ||||
|   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); | ||||
|   getResult().setType(getBroadcastedType(lhsTy, rhsTy)); | ||||
|  | @ -585,8 +593,10 @@ bool ONNXAndOp::inferShapes() { | |||
| /// shape inference interface.
 | ||||
| bool ONNXOrOp::inferShapes() { | ||||
|   if (!getOperand(0).getType().isa<RankedTensorType>() || | ||||
|       !getOperand(1).getType().isa<RankedTensorType>()) | ||||
|       !getOperand(1).getType().isa<RankedTensorType>()) { | ||||
|     emitError("Input tensor(s) not ranked"); | ||||
|     return false; | ||||
|   } | ||||
|   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); | ||||
|   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); | ||||
|   getResult().setType(getBroadcastedType(lhsTy, rhsTy)); | ||||
|  | @ -599,8 +609,10 @@ bool ONNXOrOp::inferShapes() { | |||
| /// shape inference interface.
 | ||||
| bool ONNXXorOp::inferShapes() { | ||||
|   if (!getOperand(0).getType().isa<RankedTensorType>() || | ||||
|       !getOperand(1).getType().isa<RankedTensorType>()) | ||||
|       !getOperand(1).getType().isa<RankedTensorType>()) { | ||||
|     emitError("Input tensor(s) not ranked"); | ||||
|     return false; | ||||
|   } | ||||
|   auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); | ||||
|   auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); | ||||
|   getResult().setType(getBroadcastedType(lhsTy, rhsTy)); | ||||
|  | @ -615,8 +627,10 @@ bool ONNXXorOp::inferShapes() { | |||
| /// shape inference interface.
 | ||||
| bool ONNXSumOp::inferShapes() { | ||||
|   for (int i = 0; i < getNumOperands(); ++i) { | ||||
|     if (!getOperand(i).getType().cast<RankedTensorType>()) | ||||
|     if (!getOperand(i).getType().cast<RankedTensorType>()) { | ||||
|       emitError("Input tensor(s) not ranked"); | ||||
|       return false; | ||||
|     } | ||||
|   } | ||||
|   Type resultTy = getOperand(0).getType().cast<RankedTensorType>(); | ||||
|   for (int i = 1; i < getNumOperands(); ++i) { | ||||
|  | @ -633,8 +647,10 @@ bool ONNXSumOp::inferShapes() { | |||
| /// shape inference interface.
 | ||||
| bool ONNXMaxOp::inferShapes() { | ||||
|   for (int i = 0; i < getNumOperands(); ++i) { | ||||
|     if (!getOperand(i).getType().cast<RankedTensorType>()) | ||||
|     if (!getOperand(i).getType().cast<RankedTensorType>()) { | ||||
|       emitError("Input tensor(s) not ranked"); | ||||
|       return false; | ||||
|     } | ||||
|   } | ||||
|   Type resultTy = getOperand(0).getType().cast<RankedTensorType>(); | ||||
|   for (int i = 1; i < getNumOperands(); ++i) { | ||||
|  | @ -651,8 +667,10 @@ bool ONNXMaxOp::inferShapes() { | |||
| /// shape inference interface.
 | ||||
| bool ONNXMinOp::inferShapes() { | ||||
|   for (int i = 0; i < getNumOperands(); ++i) { | ||||
|     if (!getOperand(i).getType().cast<RankedTensorType>()) | ||||
|     if (!getOperand(i).getType().cast<RankedTensorType>()) { | ||||
|       emitError("Input tensor(s) not ranked"); | ||||
|       return false; | ||||
|     } | ||||
|   } | ||||
|   Type resultTy = getOperand(0).getType().cast<RankedTensorType>(); | ||||
|   for (int i = 1; i < getNumOperands(); ++i) { | ||||
|  | @ -679,8 +697,10 @@ bool ONNXIdentityOp::inferShapes() { | |||
| bool ONNXMatMulOp::inferShapes() { | ||||
|   // Cannot infer shape if no shape exists.
 | ||||
|   if (!A().getType().isa<RankedTensorType>() || | ||||
|       !B().getType().isa<RankedTensorType>()) | ||||
|       !B().getType().isa<RankedTensorType>()) { | ||||
|     emitError("Input tensor(s) not ranked"); | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|   auto lhsTy = A().getType().cast<RankedTensorType>(); | ||||
|   auto rhsTy = B().getType().cast<RankedTensorType>(); | ||||
|  | @ -819,8 +839,10 @@ bool ONNXGemmOp::inferShapes() { | |||
|   // Cannot infer shape if no shape exists.
 | ||||
|   if (!A().getType().isa<RankedTensorType>() || | ||||
|       !B().getType().isa<RankedTensorType>() || | ||||
|       (hasBias && !C().getType().isa<RankedTensorType>())) | ||||
|       (hasBias && !C().getType().isa<RankedTensorType>())) { | ||||
|     emitError("Input tensor(s) not ranked"); | ||||
|     return false; | ||||
|   } | ||||
|   auto lhsTy = A().getType().cast<RankedTensorType>(); | ||||
|   auto rhsTy = B().getType().cast<RankedTensorType>(); | ||||
| 
 | ||||
|  | @ -862,8 +884,10 @@ bool ONNXBatchNormalizationTestModeOp::inferShapes() { | |||
|       !scale().getType().isa<RankedTensorType>() || | ||||
|       !B().getType().isa<RankedTensorType>() || | ||||
|       !mean().getType().isa<RankedTensorType>() || | ||||
|       !var().getType().isa<RankedTensorType>()) | ||||
|       !var().getType().isa<RankedTensorType>()) { | ||||
|     emitError("Input tensor(s) not ranked"); | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|   auto inputTensorTy = X().getType().cast<RankedTensorType>(); | ||||
|   auto scaleTensorTy = scale().getType().cast<RankedTensorType>(); | ||||
|  | @ -915,8 +939,15 @@ bool ONNXBatchNormalizationTestModeOp::inferShapes() { | |||
| 
 | ||||
| bool ONNXReshapeOp::inferShapes() { | ||||
|   // Cannot infer shape if no shape tensor is specified.
 | ||||
|   if (!shape().getType().isa<RankedTensorType>()) | ||||
|   if (!data().getType().isa<RankedTensorType>()) { | ||||
|     emitError("Input data tensor not ranked"); | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|   if (!shape().getType().isa<RankedTensorType>()) { | ||||
|     emitError("Shape tensor not ranked"); | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|   auto inputTensorTy = data().getType().cast<RankedTensorType>(); | ||||
|   auto shapeTensorTy = shape().getType().cast<RankedTensorType>(); | ||||
|  | @ -991,8 +1022,10 @@ bool ONNXReshapeOp::inferShapes() { | |||
| 
 | ||||
| bool ONNXTransposeOp::inferShapes() { | ||||
|   // Cannot infer shape if no shape exists.
 | ||||
|   if (!data().getType().isa<RankedTensorType>()) | ||||
|   if (!data().getType().isa<RankedTensorType>()) { | ||||
|     emitError("Input tensor not ranked"); | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|   // Naive transposition which handles the default case of
 | ||||
|   // reversing the shape of the tensor (similar to numpy.transpose).
 | ||||
|  | @ -1019,7 +1052,7 @@ bool ONNXTransposeOp::inferShapes() { | |||
| 
 | ||||
| bool ONNXReduceMaxOp::inferShapes() { | ||||
|   if (!getOperand().getType().isa<RankedTensorType>()) { | ||||
|     emitError("Shape tensor not ranked"); | ||||
|     emitError("Input tensor not ranked"); | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|  | @ -1034,7 +1067,7 @@ bool ONNXReduceMaxOp::inferShapes() { | |||
| 
 | ||||
| bool ONNXReduceMinOp::inferShapes() { | ||||
|   if (!getOperand().getType().isa<RankedTensorType>()) { | ||||
|     emitError("Shape tensor not ranked"); | ||||
|     emitError("Input tensor not ranked"); | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|  | @ -1049,7 +1082,7 @@ bool ONNXReduceMinOp::inferShapes() { | |||
| 
 | ||||
| bool ONNXReduceProdOp::inferShapes() { | ||||
|   if (!getOperand().getType().isa<RankedTensorType>()) { | ||||
|     emitError("Shape tensor not ranked"); | ||||
|     emitError("Input tensor not ranked"); | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|  | @ -1064,7 +1097,7 @@ bool ONNXReduceProdOp::inferShapes() { | |||
| 
 | ||||
| bool ONNXReduceSumOp::inferShapes() { | ||||
|   if (!getOperand().getType().isa<RankedTensorType>()) { | ||||
|     emitError("Shape tensor not ranked"); | ||||
|     emitError("Input tensor not ranked"); | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|  | @ -1097,8 +1130,10 @@ bool ONNXConvOp::inferShapes() { | |||
|   // Cannot infer shape if no shape exists.
 | ||||
|   if (!X().getType().isa<RankedTensorType>() || | ||||
|       !W().getType().isa<RankedTensorType>() || | ||||
|       (hasBias && !B().getType().isa<RankedTensorType>())) | ||||
|       (hasBias && !B().getType().isa<RankedTensorType>())) { | ||||
|     emitError("Input tensor not ranked"); | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|   auto xTy = X().getType().cast<RankedTensorType>(); | ||||
|   auto xShape = xTy.getShape(); | ||||
|  | @ -1210,8 +1245,10 @@ bool ONNXConvOp::inferShapes() { | |||
| 
 | ||||
| bool ONNXAveragePoolOp::inferShapes() { | ||||
|   // Cannot infer shape if no shape exists.
 | ||||
|   if (!X().getType().isa<RankedTensorType>()) | ||||
|   if (!X().getType().isa<RankedTensorType>()) { | ||||
|     emitError("Input tensor not ranked"); | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|   // Get shape of input.
 | ||||
|   auto xTy = X().getType().cast<RankedTensorType>(); | ||||
|  | @ -1255,8 +1292,10 @@ bool ONNXAveragePoolOp::inferShapes() { | |||
| 
 | ||||
| bool ONNXMaxPoolSingleOutOp::inferShapes() { | ||||
|   // Cannot infer shape if no shape exists.
 | ||||
|   if (!X().getType().isa<RankedTensorType>()) | ||||
|   if (!X().getType().isa<RankedTensorType>()) { | ||||
|     emitError("Input tensor not ranked"); | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|   // Get shape of input.
 | ||||
|   auto xTy = X().getType().cast<RankedTensorType>(); | ||||
|  | @ -1364,8 +1403,10 @@ void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state, | |||
| // Unsqueeze
 | ||||
| 
 | ||||
| bool ONNXUnsqueezeOp::inferShapes() { | ||||
|   if (!data().getType().isa<RankedTensorType>()) | ||||
|   if (!data().getType().isa<RankedTensorType>()) { | ||||
|     emitError("Input tensor not ranked"); | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|   auto operandTy = data().getType().cast<RankedTensorType>(); | ||||
|   int inRank = operandTy.getRank(); | ||||
|  |  | |||
|  | @ -38,6 +38,7 @@ | |||
| #include "mlir/Transforms/Passes.h" | ||||
| 
 | ||||
| enum EmissionTargetType { | ||||
|   EmitONNXBasic, | ||||
|   EmitONNXIR, | ||||
|   EmitMLIR, | ||||
|   EmitLLVMIR, | ||||
|  |  | |||
|  | @ -37,7 +37,7 @@ public: | |||
|         if (auto shape_op = dyn_cast<ShapeInference>(op)) { | ||||
|           if (!shape_op.inferShapes()) { | ||||
|             op->emitError("unable to infer shape of operation without shape " | ||||
|                           "inference interface"); | ||||
|                           "inference method"); | ||||
|             return signalPassFailure(); | ||||
|           } | ||||
|         } else { | ||||
|  |  | |||
|  | @ -23,6 +23,9 @@ int main(int argc, char *argv[]) { | |||
|   llvm::cl::opt<EmissionTargetType> emissionTarget( | ||||
|       llvm::cl::desc("Choose target to emit:"), | ||||
|       llvm::cl::values( | ||||
|           clEnumVal(EmitONNXBasic, | ||||
|                     "Ingest ONNX and emit the basic ONNX operations without" | ||||
|                     "inferred shapes."), | ||||
|           clEnumVal(EmitONNXIR, | ||||
|                     "Ingest ONNX and emit corresponding ONNX dialect."), | ||||
|           clEnumVal(EmitMLIR, | ||||
|  | @ -41,7 +44,8 @@ int main(int argc, char *argv[]) { | |||
|   processInputFile(inputFilename, emissionTarget, context, module); | ||||
| 
 | ||||
|   mlir::PassManager pm(&context); | ||||
|   addONNXToMLIRPasses(pm); | ||||
|   if (emissionTarget >= EmitONNXIR) | ||||
|     addONNXToMLIRPasses(pm); | ||||
| 
 | ||||
|   if (emissionTarget >= EmitMLIR) { | ||||
|     addONNXToKrnlPasses(pm); | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue