Error fix3 (#145)
* removed warning missing return, dangling else * fixed errors, made sure to return false in all shape inference failures * shape inference use LogicalResults as return value * format fixed * format error Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
c20aa6980e
commit
4f8fd9d1bf
|
@ -490,6 +490,7 @@ Value emitScalarOpFor<ONNXNegOp>(ConversionPatternRewriter &rewriter,
|
|||
return rewriter.create<mlir::SubIOp>(loc, zero, operand); // 0 - X = -X
|
||||
} else {
|
||||
emitError(loc, "unsupported element type");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -258,12 +258,12 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
|
|||
for (auto arg : loopBatchIVs)
|
||||
loopBatchKNIVs.emplace_back(arg);
|
||||
loopBatchKNIVs.emplace_back(loopKIVs[0]);
|
||||
if (BShape.size() >= 2)
|
||||
if (BShape.size() >= 2) {
|
||||
if (AShape.size() >= 2)
|
||||
loopBatchKNIVs.emplace_back(loopMNIVs[1]);
|
||||
else
|
||||
loopBatchKNIVs.emplace_back(loopMNIVs[0]);
|
||||
|
||||
}
|
||||
// Matmul computation
|
||||
auto loadedA = rewriter.create<LoadOp>(loc, A, loopBatchMKIVs);
|
||||
auto loadedB = rewriter.create<LoadOp>(loc, B, loopBatchKNIVs);
|
||||
|
|
|
@ -32,6 +32,7 @@ public:
|
|||
return LoopType::get(parser.getBuilder().getContext());
|
||||
|
||||
parser.emitError(parser.getCurrentLocation(), "Unknown type");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Print a type registered to this dialect.
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -25,7 +25,7 @@ def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
|
|||
|
||||
let methods = [
|
||||
InterfaceMethod<"Infer and set the output shape for the current operation.",
|
||||
"bool", "inferShapes">
|
||||
"LogicalResult", "inferShapes">
|
||||
];
|
||||
}
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ public:
|
|||
f.walk([&](mlir::Operation *op) {
|
||||
if (returnsDynamicShape(op)) {
|
||||
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
||||
if (!shape_op.inferShapes()) {
|
||||
if (failed(shape_op.inferShapes())) {
|
||||
op->emitError("unable to infer shape of operation without shape "
|
||||
"inference method");
|
||||
return signalPassFailure();
|
||||
|
|
Loading…
Reference in New Issue