Fix input argument indexing error (#69)

* Reorganize main function.

* Follow review comments.

* Emit constants are globals in Krnl and LLVM dialects.

* Fixes. Emit error before return in shape inference.

* Fix description.

* Fix emitted error message.

* Fix index name.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-04-06 11:35:17 -04:00 committed by GitHub
parent 83eb15bfae
commit 8532a10614
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 81 additions and 28 deletions

View File

@ -434,10 +434,17 @@ private:
module_.push_back(entryPoint); module_.push_back(entryPoint);
// Map graph inputs to entry block arguments. // Map graph inputs to entry block arguments.
for (int i = 0; i < graph.input().size(); ++i) // Counter of un-initialized tensors. This counter is used to index the
// entry block arguments.
int entryBlockArgIdx = 0;
for (int i = 0; i < graph.input().size(); ++i) {
if (!initializedTensors.ContainKey( if (!initializedTensors.ContainKey(
legalize_name(graph.input()[i].name()))) legalize_name(graph.input()[i].name()))) {
ImportInputTensorSymbol(graph.input()[i], entryBlock.getArguments()[i]); ImportInputTensorSymbol(
graph.input()[i], entryBlock.getArguments()[entryBlockArgIdx]);
entryBlockArgIdx++;
}
}
// Create a NoneTyped constant to be used for optional operation inputs // Create a NoneTyped constant to be used for optional operation inputs
// which are not used. // which are not used.

View File

@ -514,7 +514,7 @@ bool ONNXAbsOp::inferShapes() {
bool ONNXAddOp::inferShapes() { bool ONNXAddOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) { !getOperand(1).getType().isa<RankedTensorType>()) {
emitError("ONNXAddOp inferShapes failed"); emitError("Input tensor(s) not ranked");
return false; return false;
} }
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
@ -529,8 +529,10 @@ bool ONNXAddOp::inferShapes() {
/// shape inference interface. /// shape inference interface.
bool ONNXMulOp::inferShapes() { bool ONNXMulOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false; return false;
}
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@ -543,8 +545,10 @@ bool ONNXMulOp::inferShapes() {
/// shape inference interface. /// shape inference interface.
bool ONNXDivOp::inferShapes() { bool ONNXDivOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false; return false;
}
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@ -557,8 +561,10 @@ bool ONNXDivOp::inferShapes() {
/// shape inference interface. /// shape inference interface.
bool ONNXSubOp::inferShapes() { bool ONNXSubOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false; return false;
}
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@ -571,8 +577,10 @@ bool ONNXSubOp::inferShapes() {
/// shape inference interface. /// shape inference interface.
bool ONNXAndOp::inferShapes() { bool ONNXAndOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false; return false;
}
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@ -585,8 +593,10 @@ bool ONNXAndOp::inferShapes() {
/// shape inference interface. /// shape inference interface.
bool ONNXOrOp::inferShapes() { bool ONNXOrOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false; return false;
}
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@ -599,8 +609,10 @@ bool ONNXOrOp::inferShapes() {
/// shape inference interface. /// shape inference interface.
bool ONNXXorOp::inferShapes() { bool ONNXXorOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() || if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) !getOperand(1).getType().isa<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false; return false;
}
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>(); auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>(); auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy)); getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@ -615,8 +627,10 @@ bool ONNXXorOp::inferShapes() {
/// shape inference interface. /// shape inference interface.
bool ONNXSumOp::inferShapes() { bool ONNXSumOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) { for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i).getType().cast<RankedTensorType>()) if (!getOperand(i).getType().cast<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false; return false;
}
} }
Type resultTy = getOperand(0).getType().cast<RankedTensorType>(); Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) { for (int i = 1; i < getNumOperands(); ++i) {
@ -633,8 +647,10 @@ bool ONNXSumOp::inferShapes() {
/// shape inference interface. /// shape inference interface.
bool ONNXMaxOp::inferShapes() { bool ONNXMaxOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) { for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i).getType().cast<RankedTensorType>()) if (!getOperand(i).getType().cast<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false; return false;
}
} }
Type resultTy = getOperand(0).getType().cast<RankedTensorType>(); Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) { for (int i = 1; i < getNumOperands(); ++i) {
@ -651,8 +667,10 @@ bool ONNXMaxOp::inferShapes() {
/// shape inference interface. /// shape inference interface.
bool ONNXMinOp::inferShapes() { bool ONNXMinOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) { for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i).getType().cast<RankedTensorType>()) if (!getOperand(i).getType().cast<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false; return false;
}
} }
Type resultTy = getOperand(0).getType().cast<RankedTensorType>(); Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) { for (int i = 1; i < getNumOperands(); ++i) {
@ -679,8 +697,10 @@ bool ONNXIdentityOp::inferShapes() {
bool ONNXMatMulOp::inferShapes() { bool ONNXMatMulOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!A().getType().isa<RankedTensorType>() || if (!A().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>()) !B().getType().isa<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false; return false;
}
auto lhsTy = A().getType().cast<RankedTensorType>(); auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = B().getType().cast<RankedTensorType>(); auto rhsTy = B().getType().cast<RankedTensorType>();
@ -819,8 +839,10 @@ bool ONNXGemmOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!A().getType().isa<RankedTensorType>() || if (!A().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>() || !B().getType().isa<RankedTensorType>() ||
(hasBias && !C().getType().isa<RankedTensorType>())) (hasBias && !C().getType().isa<RankedTensorType>())) {
emitError("Input tensor(s) not ranked");
return false; return false;
}
auto lhsTy = A().getType().cast<RankedTensorType>(); auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = B().getType().cast<RankedTensorType>(); auto rhsTy = B().getType().cast<RankedTensorType>();
@ -862,8 +884,10 @@ bool ONNXBatchNormalizationTestModeOp::inferShapes() {
!scale().getType().isa<RankedTensorType>() || !scale().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>() || !B().getType().isa<RankedTensorType>() ||
!mean().getType().isa<RankedTensorType>() || !mean().getType().isa<RankedTensorType>() ||
!var().getType().isa<RankedTensorType>()) !var().getType().isa<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false; return false;
}
auto inputTensorTy = X().getType().cast<RankedTensorType>(); auto inputTensorTy = X().getType().cast<RankedTensorType>();
auto scaleTensorTy = scale().getType().cast<RankedTensorType>(); auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
@ -915,8 +939,15 @@ bool ONNXBatchNormalizationTestModeOp::inferShapes() {
bool ONNXReshapeOp::inferShapes() { bool ONNXReshapeOp::inferShapes() {
// Cannot infer shape if no shape tensor is specified. // Cannot infer shape if no shape tensor is specified.
if (!shape().getType().isa<RankedTensorType>()) if (!data().getType().isa<RankedTensorType>()) {
emitError("Input data tensor not ranked");
return false;
}
if (!shape().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked"); emitError("Shape tensor not ranked");
return false;
}
auto inputTensorTy = data().getType().cast<RankedTensorType>(); auto inputTensorTy = data().getType().cast<RankedTensorType>();
auto shapeTensorTy = shape().getType().cast<RankedTensorType>(); auto shapeTensorTy = shape().getType().cast<RankedTensorType>();
@ -991,8 +1022,10 @@ bool ONNXReshapeOp::inferShapes() {
bool ONNXTransposeOp::inferShapes() { bool ONNXTransposeOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!data().getType().isa<RankedTensorType>()) if (!data().getType().isa<RankedTensorType>()) {
emitError("Input tensor not ranked");
return false; return false;
}
// Naive transposition which handles the default case of // Naive transposition which handles the default case of
// reversing the shape of the tensor (similar to numpy.transpose). // reversing the shape of the tensor (similar to numpy.transpose).
@ -1019,7 +1052,7 @@ bool ONNXTransposeOp::inferShapes() {
bool ONNXReduceMaxOp::inferShapes() { bool ONNXReduceMaxOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) { if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked"); emitError("Input tensor not ranked");
return false; return false;
} }
@ -1034,7 +1067,7 @@ bool ONNXReduceMaxOp::inferShapes() {
bool ONNXReduceMinOp::inferShapes() { bool ONNXReduceMinOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) { if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked"); emitError("Input tensor not ranked");
return false; return false;
} }
@ -1049,7 +1082,7 @@ bool ONNXReduceMinOp::inferShapes() {
bool ONNXReduceProdOp::inferShapes() { bool ONNXReduceProdOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) { if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked"); emitError("Input tensor not ranked");
return false; return false;
} }
@ -1064,7 +1097,7 @@ bool ONNXReduceProdOp::inferShapes() {
bool ONNXReduceSumOp::inferShapes() { bool ONNXReduceSumOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) { if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked"); emitError("Input tensor not ranked");
return false; return false;
} }
@ -1097,8 +1130,10 @@ bool ONNXConvOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>() || if (!X().getType().isa<RankedTensorType>() ||
!W().getType().isa<RankedTensorType>() || !W().getType().isa<RankedTensorType>() ||
(hasBias && !B().getType().isa<RankedTensorType>())) (hasBias && !B().getType().isa<RankedTensorType>())) {
emitError("Input tensor not ranked");
return false; return false;
}
auto xTy = X().getType().cast<RankedTensorType>(); auto xTy = X().getType().cast<RankedTensorType>();
auto xShape = xTy.getShape(); auto xShape = xTy.getShape();
@ -1210,8 +1245,10 @@ bool ONNXConvOp::inferShapes() {
bool ONNXAveragePoolOp::inferShapes() { bool ONNXAveragePoolOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>()) if (!X().getType().isa<RankedTensorType>()) {
emitError("Input tensor not ranked");
return false; return false;
}
// Get shape of input. // Get shape of input.
auto xTy = X().getType().cast<RankedTensorType>(); auto xTy = X().getType().cast<RankedTensorType>();
@ -1255,8 +1292,10 @@ bool ONNXAveragePoolOp::inferShapes() {
bool ONNXMaxPoolSingleOutOp::inferShapes() { bool ONNXMaxPoolSingleOutOp::inferShapes() {
// Cannot infer shape if no shape exists. // Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>()) if (!X().getType().isa<RankedTensorType>()) {
emitError("Input tensor not ranked");
return false; return false;
}
// Get shape of input. // Get shape of input.
auto xTy = X().getType().cast<RankedTensorType>(); auto xTy = X().getType().cast<RankedTensorType>();
@ -1364,8 +1403,10 @@ void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
// Unsqueeze // Unsqueeze
bool ONNXUnsqueezeOp::inferShapes() { bool ONNXUnsqueezeOp::inferShapes() {
if (!data().getType().isa<RankedTensorType>()) if (!data().getType().isa<RankedTensorType>()) {
emitError("Input tensor not ranked");
return false; return false;
}
auto operandTy = data().getType().cast<RankedTensorType>(); auto operandTy = data().getType().cast<RankedTensorType>();
int inRank = operandTy.getRank(); int inRank = operandTy.getRank();

View File

@ -38,6 +38,7 @@
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
enum EmissionTargetType { enum EmissionTargetType {
EmitONNXBasic,
EmitONNXIR, EmitONNXIR,
EmitMLIR, EmitMLIR,
EmitLLVMIR, EmitLLVMIR,

View File

@ -37,7 +37,7 @@ public:
if (auto shape_op = dyn_cast<ShapeInference>(op)) { if (auto shape_op = dyn_cast<ShapeInference>(op)) {
if (!shape_op.inferShapes()) { if (!shape_op.inferShapes()) {
op->emitError("unable to infer shape of operation without shape " op->emitError("unable to infer shape of operation without shape "
"inference interface"); "inference method");
return signalPassFailure(); return signalPassFailure();
} }
} else { } else {

View File

@ -23,6 +23,9 @@ int main(int argc, char *argv[]) {
llvm::cl::opt<EmissionTargetType> emissionTarget( llvm::cl::opt<EmissionTargetType> emissionTarget(
llvm::cl::desc("Choose target to emit:"), llvm::cl::desc("Choose target to emit:"),
llvm::cl::values( llvm::cl::values(
clEnumVal(EmitONNXBasic,
"Ingest ONNX and emit the basic ONNX operations without"
"inferred shapes."),
clEnumVal(EmitONNXIR, clEnumVal(EmitONNXIR,
"Ingest ONNX and emit corresponding ONNX dialect."), "Ingest ONNX and emit corresponding ONNX dialect."),
clEnumVal(EmitMLIR, clEnumVal(EmitMLIR,
@ -41,7 +44,8 @@ int main(int argc, char *argv[]) {
processInputFile(inputFilename, emissionTarget, context, module); processInputFile(inputFilename, emissionTarget, context, module);
mlir::PassManager pm(&context); mlir::PassManager pm(&context);
addONNXToMLIRPasses(pm); if (emissionTarget >= EmitONNXIR)
addONNXToMLIRPasses(pm);
if (emissionTarget >= EmitMLIR) { if (emissionTarget >= EmitMLIR) {
addONNXToKrnlPasses(pm); addONNXToKrnlPasses(pm);