Error fix3 (#145)

* removed warning missing return, dangling else

* fixed errors, made sure to return false in all shape inference failures

* shape inference use LogicalResults as return value

* format fixed

* format error

Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
Alexandre Eichenberger 2020-05-26 22:09:28 -04:00 committed by GitHub
parent c20aa6980e
commit 4f8fd9d1bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 293 additions and 377 deletions

View File

@ -490,6 +490,7 @@ Value emitScalarOpFor<ONNXNegOp>(ConversionPatternRewriter &rewriter,
return rewriter.create<mlir::SubIOp>(loc, zero, operand); // 0 - X = -X return rewriter.create<mlir::SubIOp>(loc, zero, operand); // 0 - X = -X
} else { } else {
emitError(loc, "unsupported element type"); emitError(loc, "unsupported element type");
return nullptr;
} }
} }

View File

@ -258,12 +258,12 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
for (auto arg : loopBatchIVs) for (auto arg : loopBatchIVs)
loopBatchKNIVs.emplace_back(arg); loopBatchKNIVs.emplace_back(arg);
loopBatchKNIVs.emplace_back(loopKIVs[0]); loopBatchKNIVs.emplace_back(loopKIVs[0]);
if (BShape.size() >= 2) if (BShape.size() >= 2) {
if (AShape.size() >= 2) if (AShape.size() >= 2)
loopBatchKNIVs.emplace_back(loopMNIVs[1]); loopBatchKNIVs.emplace_back(loopMNIVs[1]);
else else
loopBatchKNIVs.emplace_back(loopMNIVs[0]); loopBatchKNIVs.emplace_back(loopMNIVs[0]);
}
// Matmul computation // Matmul computation
auto loadedA = rewriter.create<LoadOp>(loc, A, loopBatchMKIVs); auto loadedA = rewriter.create<LoadOp>(loc, A, loopBatchMKIVs);
auto loadedB = rewriter.create<LoadOp>(loc, B, loopBatchKNIVs); auto loadedB = rewriter.create<LoadOp>(loc, B, loopBatchKNIVs);

View File

@ -32,6 +32,7 @@ public:
return LoopType::get(parser.getBuilder().getContext()); return LoopType::get(parser.getBuilder().getContext());
parser.emitError(parser.getCurrentLocation(), "Unknown type"); parser.emitError(parser.getCurrentLocation(), "Unknown type");
return nullptr;
} }
/// Print a type registered to this dialect. /// Print a type registered to this dialect.

File diff suppressed because it is too large Load Diff

View File

@ -25,7 +25,7 @@ def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
let methods = [ let methods = [
InterfaceMethod<"Infer and set the output shape for the current operation.", InterfaceMethod<"Infer and set the output shape for the current operation.",
"bool", "inferShapes"> "LogicalResult", "inferShapes">
]; ];
} }

View File

@ -36,7 +36,7 @@ public:
f.walk([&](mlir::Operation *op) { f.walk([&](mlir::Operation *op) {
if (returnsDynamicShape(op)) { if (returnsDynamicShape(op)) {
if (auto shape_op = dyn_cast<ShapeInference>(op)) { if (auto shape_op = dyn_cast<ShapeInference>(op)) {
if (!shape_op.inferShapes()) { if (failed(shape_op.inferShapes())) {
op->emitError("unable to infer shape of operation without shape " op->emitError("unable to infer shape of operation without shape "
"inference method"); "inference method");
return signalPassFailure(); return signalPassFailure();