Fix input argument indexing error (#69)

* Reorganize main function.

* Follow review comments.

* Emit constants are globals in Krnl and LLVM dialects.

* Fixes. Emit error before return in shape inference.

* Fix description.

* Fix emitted error message.

* Fix index name.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-04-06 11:35:17 -04:00 committed by GitHub
parent 83eb15bfae
commit 8532a10614
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 81 additions and 28 deletions

View File

@ -434,10 +434,17 @@ private:
module_.push_back(entryPoint);
// Map graph inputs to entry block arguments.
for (int i = 0; i < graph.input().size(); ++i)
// Counter of un-initialized tensors. This counter is used to index the
// entry block arguments.
int entryBlockArgIdx = 0;
for (int i = 0; i < graph.input().size(); ++i) {
if (!initializedTensors.ContainKey(
legalize_name(graph.input()[i].name())))
ImportInputTensorSymbol(graph.input()[i], entryBlock.getArguments()[i]);
legalize_name(graph.input()[i].name()))) {
ImportInputTensorSymbol(
graph.input()[i], entryBlock.getArguments()[entryBlockArgIdx]);
entryBlockArgIdx++;
}
}
// Create a NoneTyped constant to be used for optional operation inputs
// which are not used.

View File

@ -514,7 +514,7 @@ bool ONNXAbsOp::inferShapes() {
bool ONNXAddOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>()) {
emitError("ONNXAddOp inferShapes failed");
emitError("Input tensor(s) not ranked");
return false;
}
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
@ -529,8 +529,10 @@ bool ONNXAddOp::inferShapes() {
/// shape inference interface.
bool ONNXMulOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
!getOperand(1).getType().isa<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false;
}
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@ -543,8 +545,10 @@ bool ONNXMulOp::inferShapes() {
/// shape inference interface.
bool ONNXDivOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
!getOperand(1).getType().isa<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false;
}
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@ -557,8 +561,10 @@ bool ONNXDivOp::inferShapes() {
/// shape inference interface.
bool ONNXSubOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
!getOperand(1).getType().isa<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false;
}
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@ -571,8 +577,10 @@ bool ONNXSubOp::inferShapes() {
/// shape inference interface.
bool ONNXAndOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
!getOperand(1).getType().isa<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false;
}
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@ -585,8 +593,10 @@ bool ONNXAndOp::inferShapes() {
/// shape inference interface.
bool ONNXOrOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
!getOperand(1).getType().isa<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false;
}
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@ -599,8 +609,10 @@ bool ONNXOrOp::inferShapes() {
/// shape inference interface.
bool ONNXXorOp::inferShapes() {
if (!getOperand(0).getType().isa<RankedTensorType>() ||
!getOperand(1).getType().isa<RankedTensorType>())
!getOperand(1).getType().isa<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false;
}
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
@ -615,9 +627,11 @@ bool ONNXXorOp::inferShapes() {
/// shape inference interface.
bool ONNXSumOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i).getType().cast<RankedTensorType>())
if (!getOperand(i).getType().cast<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false;
}
}
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) {
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
@ -633,9 +647,11 @@ bool ONNXSumOp::inferShapes() {
/// shape inference interface.
bool ONNXMaxOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i).getType().cast<RankedTensorType>())
if (!getOperand(i).getType().cast<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false;
}
}
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) {
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
@ -651,9 +667,11 @@ bool ONNXMaxOp::inferShapes() {
/// shape inference interface.
bool ONNXMinOp::inferShapes() {
for (int i = 0; i < getNumOperands(); ++i) {
if (!getOperand(i).getType().cast<RankedTensorType>())
if (!getOperand(i).getType().cast<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false;
}
}
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
for (int i = 1; i < getNumOperands(); ++i) {
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
@ -679,8 +697,10 @@ bool ONNXIdentityOp::inferShapes() {
bool ONNXMatMulOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!A().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>())
!B().getType().isa<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false;
}
auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = B().getType().cast<RankedTensorType>();
@ -819,8 +839,10 @@ bool ONNXGemmOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!A().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>() ||
(hasBias && !C().getType().isa<RankedTensorType>()))
(hasBias && !C().getType().isa<RankedTensorType>())) {
emitError("Input tensor(s) not ranked");
return false;
}
auto lhsTy = A().getType().cast<RankedTensorType>();
auto rhsTy = B().getType().cast<RankedTensorType>();
@ -862,8 +884,10 @@ bool ONNXBatchNormalizationTestModeOp::inferShapes() {
!scale().getType().isa<RankedTensorType>() ||
!B().getType().isa<RankedTensorType>() ||
!mean().getType().isa<RankedTensorType>() ||
!var().getType().isa<RankedTensorType>())
!var().getType().isa<RankedTensorType>()) {
emitError("Input tensor(s) not ranked");
return false;
}
auto inputTensorTy = X().getType().cast<RankedTensorType>();
auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
@ -915,8 +939,15 @@ bool ONNXBatchNormalizationTestModeOp::inferShapes() {
bool ONNXReshapeOp::inferShapes() {
// Cannot infer shape if no shape tensor is specified.
if (!shape().getType().isa<RankedTensorType>())
if (!data().getType().isa<RankedTensorType>()) {
emitError("Input data tensor not ranked");
return false;
}
if (!shape().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked");
return false;
}
auto inputTensorTy = data().getType().cast<RankedTensorType>();
auto shapeTensorTy = shape().getType().cast<RankedTensorType>();
@ -991,8 +1022,10 @@ bool ONNXReshapeOp::inferShapes() {
bool ONNXTransposeOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!data().getType().isa<RankedTensorType>())
if (!data().getType().isa<RankedTensorType>()) {
emitError("Input tensor not ranked");
return false;
}
// Naive transposition which handles the default case of
// reversing the shape of the tensor (similar to numpy.transpose).
@ -1019,7 +1052,7 @@ bool ONNXTransposeOp::inferShapes() {
bool ONNXReduceMaxOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked");
emitError("Input tensor not ranked");
return false;
}
@ -1034,7 +1067,7 @@ bool ONNXReduceMaxOp::inferShapes() {
bool ONNXReduceMinOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked");
emitError("Input tensor not ranked");
return false;
}
@ -1049,7 +1082,7 @@ bool ONNXReduceMinOp::inferShapes() {
bool ONNXReduceProdOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked");
emitError("Input tensor not ranked");
return false;
}
@ -1064,7 +1097,7 @@ bool ONNXReduceProdOp::inferShapes() {
bool ONNXReduceSumOp::inferShapes() {
if (!getOperand().getType().isa<RankedTensorType>()) {
emitError("Shape tensor not ranked");
emitError("Input tensor not ranked");
return false;
}
@ -1097,8 +1130,10 @@ bool ONNXConvOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>() ||
!W().getType().isa<RankedTensorType>() ||
(hasBias && !B().getType().isa<RankedTensorType>()))
(hasBias && !B().getType().isa<RankedTensorType>())) {
emitError("Input tensor not ranked");
return false;
}
auto xTy = X().getType().cast<RankedTensorType>();
auto xShape = xTy.getShape();
@ -1210,8 +1245,10 @@ bool ONNXConvOp::inferShapes() {
bool ONNXAveragePoolOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>())
if (!X().getType().isa<RankedTensorType>()) {
emitError("Input tensor not ranked");
return false;
}
// Get shape of input.
auto xTy = X().getType().cast<RankedTensorType>();
@ -1255,8 +1292,10 @@ bool ONNXAveragePoolOp::inferShapes() {
bool ONNXMaxPoolSingleOutOp::inferShapes() {
// Cannot infer shape if no shape exists.
if (!X().getType().isa<RankedTensorType>())
if (!X().getType().isa<RankedTensorType>()) {
emitError("Input tensor not ranked");
return false;
}
// Get shape of input.
auto xTy = X().getType().cast<RankedTensorType>();
@ -1364,8 +1403,10 @@ void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
// Unsqueeze
bool ONNXUnsqueezeOp::inferShapes() {
if (!data().getType().isa<RankedTensorType>())
if (!data().getType().isa<RankedTensorType>()) {
emitError("Input tensor not ranked");
return false;
}
auto operandTy = data().getType().cast<RankedTensorType>();
int inRank = operandTy.getRank();

View File

@ -38,6 +38,7 @@
#include "mlir/Transforms/Passes.h"
enum EmissionTargetType {
EmitONNXBasic,
EmitONNXIR,
EmitMLIR,
EmitLLVMIR,

View File

@ -37,7 +37,7 @@ public:
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
if (!shape_op.inferShapes()) {
op->emitError("unable to infer shape of operation without shape "
"inference interface");
"inference method");
return signalPassFailure();
}
} else {

View File

@ -23,6 +23,9 @@ int main(int argc, char *argv[]) {
llvm::cl::opt<EmissionTargetType> emissionTarget(
llvm::cl::desc("Choose target to emit:"),
llvm::cl::values(
clEnumVal(EmitONNXBasic,
"Ingest ONNX and emit the basic ONNX operations without"
"inferred shapes."),
clEnumVal(EmitONNXIR,
"Ingest ONNX and emit corresponding ONNX dialect."),
clEnumVal(EmitMLIR,
@ -41,6 +44,7 @@ int main(int argc, char *argv[]) {
processInputFile(inputFilename, emissionTarget, context, module);
mlir::PassManager pm(&context);
if (emissionTarget >= EmitONNXIR)
addONNXToMLIRPasses(pm);
if (emissionTarget >= EmitMLIR) {