Error fix3 (#145)
* removed warning missing return, dangling else * fixed errors, made sure to return false in all shape inference failures * shape inference use LogicalResults as return value * format fixed * format error Co-authored-by: Tian Jin <tjingrant@gmail.com>
This commit is contained in:
parent
c20aa6980e
commit
4f8fd9d1bf
|
@ -490,6 +490,7 @@ Value emitScalarOpFor<ONNXNegOp>(ConversionPatternRewriter &rewriter,
|
||||||
return rewriter.create<mlir::SubIOp>(loc, zero, operand); // 0 - X = -X
|
return rewriter.create<mlir::SubIOp>(loc, zero, operand); // 0 - X = -X
|
||||||
} else {
|
} else {
|
||||||
emitError(loc, "unsupported element type");
|
emitError(loc, "unsupported element type");
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -258,12 +258,12 @@ struct ONNXMatMulOpLowering : public ConversionPattern {
|
||||||
for (auto arg : loopBatchIVs)
|
for (auto arg : loopBatchIVs)
|
||||||
loopBatchKNIVs.emplace_back(arg);
|
loopBatchKNIVs.emplace_back(arg);
|
||||||
loopBatchKNIVs.emplace_back(loopKIVs[0]);
|
loopBatchKNIVs.emplace_back(loopKIVs[0]);
|
||||||
if (BShape.size() >= 2)
|
if (BShape.size() >= 2) {
|
||||||
if (AShape.size() >= 2)
|
if (AShape.size() >= 2)
|
||||||
loopBatchKNIVs.emplace_back(loopMNIVs[1]);
|
loopBatchKNIVs.emplace_back(loopMNIVs[1]);
|
||||||
else
|
else
|
||||||
loopBatchKNIVs.emplace_back(loopMNIVs[0]);
|
loopBatchKNIVs.emplace_back(loopMNIVs[0]);
|
||||||
|
}
|
||||||
// Matmul computation
|
// Matmul computation
|
||||||
auto loadedA = rewriter.create<LoadOp>(loc, A, loopBatchMKIVs);
|
auto loadedA = rewriter.create<LoadOp>(loc, A, loopBatchMKIVs);
|
||||||
auto loadedB = rewriter.create<LoadOp>(loc, B, loopBatchKNIVs);
|
auto loadedB = rewriter.create<LoadOp>(loc, B, loopBatchKNIVs);
|
||||||
|
|
|
@ -32,6 +32,7 @@ public:
|
||||||
return LoopType::get(parser.getBuilder().getContext());
|
return LoopType::get(parser.getBuilder().getContext());
|
||||||
|
|
||||||
parser.emitError(parser.getCurrentLocation(), "Unknown type");
|
parser.emitError(parser.getCurrentLocation(), "Unknown type");
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Print a type registered to this dialect.
|
/// Print a type registered to this dialect.
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -25,7 +25,7 @@ def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
|
||||||
|
|
||||||
let methods = [
|
let methods = [
|
||||||
InterfaceMethod<"Infer and set the output shape for the current operation.",
|
InterfaceMethod<"Infer and set the output shape for the current operation.",
|
||||||
"bool", "inferShapes">
|
"LogicalResult", "inferShapes">
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -36,7 +36,7 @@ public:
|
||||||
f.walk([&](mlir::Operation *op) {
|
f.walk([&](mlir::Operation *op) {
|
||||||
if (returnsDynamicShape(op)) {
|
if (returnsDynamicShape(op)) {
|
||||||
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
||||||
if (!shape_op.inferShapes()) {
|
if (failed(shape_op.inferShapes())) {
|
||||||
op->emitError("unable to infer shape of operation without shape "
|
op->emitError("unable to infer shape of operation without shape "
|
||||||
"inference method");
|
"inference method");
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
|
Loading…
Reference in New Issue