Fix input argument indexing error (#69)
* Reorganize main function. * Follow review comments. * Emit constants are globals in Krnl and LLVM dialects. * Fixes. Emit error before return in shape inference. * Fix description. * Fix emitted error message. * Fix index name.
This commit is contained in:
parent
83eb15bfae
commit
8532a10614
|
@ -434,10 +434,17 @@ private:
|
|||
module_.push_back(entryPoint);
|
||||
|
||||
// Map graph inputs to entry block arguments.
|
||||
for (int i = 0; i < graph.input().size(); ++i)
|
||||
// Counter of un-initialized tensors. This counter is used to index the
|
||||
// entry block arguments.
|
||||
int entryBlockArgIdx = 0;
|
||||
for (int i = 0; i < graph.input().size(); ++i) {
|
||||
if (!initializedTensors.ContainKey(
|
||||
legalize_name(graph.input()[i].name())))
|
||||
ImportInputTensorSymbol(graph.input()[i], entryBlock.getArguments()[i]);
|
||||
legalize_name(graph.input()[i].name()))) {
|
||||
ImportInputTensorSymbol(
|
||||
graph.input()[i], entryBlock.getArguments()[entryBlockArgIdx]);
|
||||
entryBlockArgIdx++;
|
||||
}
|
||||
}
|
||||
|
||||
// Create a NoneTyped constant to be used for optional operation inputs
|
||||
// which are not used.
|
||||
|
|
|
@ -514,7 +514,7 @@ bool ONNXAbsOp::inferShapes() {
|
|||
bool ONNXAddOp::inferShapes() {
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>()) {
|
||||
emitError("ONNXAddOp inferShapes failed");
|
||||
emitError("Input tensor(s) not ranked");
|
||||
return false;
|
||||
}
|
||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
|
@ -529,8 +529,10 @@ bool ONNXAddOp::inferShapes() {
|
|||
/// shape inference interface.
|
||||
bool ONNXMulOp::inferShapes() {
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
!getOperand(1).getType().isa<RankedTensorType>()) {
|
||||
emitError("Input tensor(s) not ranked");
|
||||
return false;
|
||||
}
|
||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||
|
@ -543,8 +545,10 @@ bool ONNXMulOp::inferShapes() {
|
|||
/// shape inference interface.
|
||||
bool ONNXDivOp::inferShapes() {
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
!getOperand(1).getType().isa<RankedTensorType>()) {
|
||||
emitError("Input tensor(s) not ranked");
|
||||
return false;
|
||||
}
|
||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||
|
@ -557,8 +561,10 @@ bool ONNXDivOp::inferShapes() {
|
|||
/// shape inference interface.
|
||||
bool ONNXSubOp::inferShapes() {
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
!getOperand(1).getType().isa<RankedTensorType>()) {
|
||||
emitError("Input tensor(s) not ranked");
|
||||
return false;
|
||||
}
|
||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||
|
@ -571,8 +577,10 @@ bool ONNXSubOp::inferShapes() {
|
|||
/// shape inference interface.
|
||||
bool ONNXAndOp::inferShapes() {
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
!getOperand(1).getType().isa<RankedTensorType>()) {
|
||||
emitError("Input tensor(s) not ranked");
|
||||
return false;
|
||||
}
|
||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||
|
@ -585,8 +593,10 @@ bool ONNXAndOp::inferShapes() {
|
|||
/// shape inference interface.
|
||||
bool ONNXOrOp::inferShapes() {
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
!getOperand(1).getType().isa<RankedTensorType>()) {
|
||||
emitError("Input tensor(s) not ranked");
|
||||
return false;
|
||||
}
|
||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||
|
@ -599,8 +609,10 @@ bool ONNXOrOp::inferShapes() {
|
|||
/// shape inference interface.
|
||||
bool ONNXXorOp::inferShapes() {
|
||||
if (!getOperand(0).getType().isa<RankedTensorType>() ||
|
||||
!getOperand(1).getType().isa<RankedTensorType>())
|
||||
!getOperand(1).getType().isa<RankedTensorType>()) {
|
||||
emitError("Input tensor(s) not ranked");
|
||||
return false;
|
||||
}
|
||||
auto lhsTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto rhsTy = getOperand(1).getType().cast<RankedTensorType>();
|
||||
getResult().setType(getBroadcastedType(lhsTy, rhsTy));
|
||||
|
@ -615,9 +627,11 @@ bool ONNXXorOp::inferShapes() {
|
|||
/// shape inference interface.
|
||||
bool ONNXSumOp::inferShapes() {
|
||||
for (int i = 0; i < getNumOperands(); ++i) {
|
||||
if (!getOperand(i).getType().cast<RankedTensorType>())
|
||||
if (!getOperand(i).getType().cast<RankedTensorType>()) {
|
||||
emitError("Input tensor(s) not ranked");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
for (int i = 1; i < getNumOperands(); ++i) {
|
||||
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
|
||||
|
@ -633,9 +647,11 @@ bool ONNXSumOp::inferShapes() {
|
|||
/// shape inference interface.
|
||||
bool ONNXMaxOp::inferShapes() {
|
||||
for (int i = 0; i < getNumOperands(); ++i) {
|
||||
if (!getOperand(i).getType().cast<RankedTensorType>())
|
||||
if (!getOperand(i).getType().cast<RankedTensorType>()) {
|
||||
emitError("Input tensor(s) not ranked");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
for (int i = 1; i < getNumOperands(); ++i) {
|
||||
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
|
||||
|
@ -651,9 +667,11 @@ bool ONNXMaxOp::inferShapes() {
|
|||
/// shape inference interface.
|
||||
bool ONNXMinOp::inferShapes() {
|
||||
for (int i = 0; i < getNumOperands(); ++i) {
|
||||
if (!getOperand(i).getType().cast<RankedTensorType>())
|
||||
if (!getOperand(i).getType().cast<RankedTensorType>()) {
|
||||
emitError("Input tensor(s) not ranked");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
Type resultTy = getOperand(0).getType().cast<RankedTensorType>();
|
||||
for (int i = 1; i < getNumOperands(); ++i) {
|
||||
Type nextTy = getOperand(i).getType().cast<RankedTensorType>();
|
||||
|
@ -679,8 +697,10 @@ bool ONNXIdentityOp::inferShapes() {
|
|||
bool ONNXMatMulOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!A().getType().isa<RankedTensorType>() ||
|
||||
!B().getType().isa<RankedTensorType>())
|
||||
!B().getType().isa<RankedTensorType>()) {
|
||||
emitError("Input tensor(s) not ranked");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto lhsTy = A().getType().cast<RankedTensorType>();
|
||||
auto rhsTy = B().getType().cast<RankedTensorType>();
|
||||
|
@ -819,8 +839,10 @@ bool ONNXGemmOp::inferShapes() {
|
|||
// Cannot infer shape if no shape exists.
|
||||
if (!A().getType().isa<RankedTensorType>() ||
|
||||
!B().getType().isa<RankedTensorType>() ||
|
||||
(hasBias && !C().getType().isa<RankedTensorType>()))
|
||||
(hasBias && !C().getType().isa<RankedTensorType>())) {
|
||||
emitError("Input tensor(s) not ranked");
|
||||
return false;
|
||||
}
|
||||
auto lhsTy = A().getType().cast<RankedTensorType>();
|
||||
auto rhsTy = B().getType().cast<RankedTensorType>();
|
||||
|
||||
|
@ -862,8 +884,10 @@ bool ONNXBatchNormalizationTestModeOp::inferShapes() {
|
|||
!scale().getType().isa<RankedTensorType>() ||
|
||||
!B().getType().isa<RankedTensorType>() ||
|
||||
!mean().getType().isa<RankedTensorType>() ||
|
||||
!var().getType().isa<RankedTensorType>())
|
||||
!var().getType().isa<RankedTensorType>()) {
|
||||
emitError("Input tensor(s) not ranked");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto inputTensorTy = X().getType().cast<RankedTensorType>();
|
||||
auto scaleTensorTy = scale().getType().cast<RankedTensorType>();
|
||||
|
@ -915,8 +939,15 @@ bool ONNXBatchNormalizationTestModeOp::inferShapes() {
|
|||
|
||||
bool ONNXReshapeOp::inferShapes() {
|
||||
// Cannot infer shape if no shape tensor is specified.
|
||||
if (!shape().getType().isa<RankedTensorType>())
|
||||
if (!data().getType().isa<RankedTensorType>()) {
|
||||
emitError("Input data tensor not ranked");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!shape().getType().isa<RankedTensorType>()) {
|
||||
emitError("Shape tensor not ranked");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto inputTensorTy = data().getType().cast<RankedTensorType>();
|
||||
auto shapeTensorTy = shape().getType().cast<RankedTensorType>();
|
||||
|
@ -991,8 +1022,10 @@ bool ONNXReshapeOp::inferShapes() {
|
|||
|
||||
bool ONNXTransposeOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!data().getType().isa<RankedTensorType>())
|
||||
if (!data().getType().isa<RankedTensorType>()) {
|
||||
emitError("Input tensor not ranked");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Naive transposition which handles the default case of
|
||||
// reversing the shape of the tensor (similar to numpy.transpose).
|
||||
|
@ -1019,7 +1052,7 @@ bool ONNXTransposeOp::inferShapes() {
|
|||
|
||||
bool ONNXReduceMaxOp::inferShapes() {
|
||||
if (!getOperand().getType().isa<RankedTensorType>()) {
|
||||
emitError("Shape tensor not ranked");
|
||||
emitError("Input tensor not ranked");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -1034,7 +1067,7 @@ bool ONNXReduceMaxOp::inferShapes() {
|
|||
|
||||
bool ONNXReduceMinOp::inferShapes() {
|
||||
if (!getOperand().getType().isa<RankedTensorType>()) {
|
||||
emitError("Shape tensor not ranked");
|
||||
emitError("Input tensor not ranked");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -1049,7 +1082,7 @@ bool ONNXReduceMinOp::inferShapes() {
|
|||
|
||||
bool ONNXReduceProdOp::inferShapes() {
|
||||
if (!getOperand().getType().isa<RankedTensorType>()) {
|
||||
emitError("Shape tensor not ranked");
|
||||
emitError("Input tensor not ranked");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -1064,7 +1097,7 @@ bool ONNXReduceProdOp::inferShapes() {
|
|||
|
||||
bool ONNXReduceSumOp::inferShapes() {
|
||||
if (!getOperand().getType().isa<RankedTensorType>()) {
|
||||
emitError("Shape tensor not ranked");
|
||||
emitError("Input tensor not ranked");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -1097,8 +1130,10 @@ bool ONNXConvOp::inferShapes() {
|
|||
// Cannot infer shape if no shape exists.
|
||||
if (!X().getType().isa<RankedTensorType>() ||
|
||||
!W().getType().isa<RankedTensorType>() ||
|
||||
(hasBias && !B().getType().isa<RankedTensorType>()))
|
||||
(hasBias && !B().getType().isa<RankedTensorType>())) {
|
||||
emitError("Input tensor not ranked");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto xTy = X().getType().cast<RankedTensorType>();
|
||||
auto xShape = xTy.getShape();
|
||||
|
@ -1210,8 +1245,10 @@ bool ONNXConvOp::inferShapes() {
|
|||
|
||||
bool ONNXAveragePoolOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!X().getType().isa<RankedTensorType>())
|
||||
if (!X().getType().isa<RankedTensorType>()) {
|
||||
emitError("Input tensor not ranked");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get shape of input.
|
||||
auto xTy = X().getType().cast<RankedTensorType>();
|
||||
|
@ -1255,8 +1292,10 @@ bool ONNXAveragePoolOp::inferShapes() {
|
|||
|
||||
bool ONNXMaxPoolSingleOutOp::inferShapes() {
|
||||
// Cannot infer shape if no shape exists.
|
||||
if (!X().getType().isa<RankedTensorType>())
|
||||
if (!X().getType().isa<RankedTensorType>()) {
|
||||
emitError("Input tensor not ranked");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get shape of input.
|
||||
auto xTy = X().getType().cast<RankedTensorType>();
|
||||
|
@ -1364,8 +1403,10 @@ void ONNXPadConstantValuePadOp::build(Builder *builder, OperationState &state,
|
|||
// Unsqueeze
|
||||
|
||||
bool ONNXUnsqueezeOp::inferShapes() {
|
||||
if (!data().getType().isa<RankedTensorType>())
|
||||
if (!data().getType().isa<RankedTensorType>()) {
|
||||
emitError("Input tensor not ranked");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto operandTy = data().getType().cast<RankedTensorType>();
|
||||
int inRank = operandTy.getRank();
|
||||
|
|
|
@ -38,6 +38,7 @@
|
|||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
enum EmissionTargetType {
|
||||
EmitONNXBasic,
|
||||
EmitONNXIR,
|
||||
EmitMLIR,
|
||||
EmitLLVMIR,
|
||||
|
|
|
@ -37,7 +37,7 @@ public:
|
|||
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
||||
if (!shape_op.inferShapes()) {
|
||||
op->emitError("unable to infer shape of operation without shape "
|
||||
"inference interface");
|
||||
"inference method");
|
||||
return signalPassFailure();
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -23,6 +23,9 @@ int main(int argc, char *argv[]) {
|
|||
llvm::cl::opt<EmissionTargetType> emissionTarget(
|
||||
llvm::cl::desc("Choose target to emit:"),
|
||||
llvm::cl::values(
|
||||
clEnumVal(EmitONNXBasic,
|
||||
"Ingest ONNX and emit the basic ONNX operations without"
|
||||
"inferred shapes."),
|
||||
clEnumVal(EmitONNXIR,
|
||||
"Ingest ONNX and emit corresponding ONNX dialect."),
|
||||
clEnumVal(EmitMLIR,
|
||||
|
@ -41,6 +44,7 @@ int main(int argc, char *argv[]) {
|
|||
processInputFile(inputFilename, emissionTarget, context, module);
|
||||
|
||||
mlir::PassManager pm(&context);
|
||||
if (emissionTarget >= EmitONNXIR)
|
||||
addONNXToMLIRPasses(pm);
|
||||
|
||||
if (emissionTarget >= EmitMLIR) {
|
||||
|
|
Loading…
Reference in New Issue