Enable inference for arbitrary number of instructions (#12)

* Fix shape inference.

* Remove comment.

* Remove worklist since it is not needed.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-03-10 14:16:03 -04:00 committed by GitHub
parent 1882059ac9
commit ba02b90e0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 19 additions and 29 deletions

View File

@ -34,27 +34,10 @@ public:
void runOnFunction() override { void runOnFunction() override {
auto f = getFunction(); auto f = getFunction();
// Populate the worklist with the operations that need shape inference: // Iterate on the operations that need shape inference i.e the operations
// these are operations that return a dynamic shape. // that return a dynamic shape.
llvm::SmallPtrSet<mlir::Operation *, 16> op_worklist;
f.walk([&](mlir::Operation *op) { f.walk([&](mlir::Operation *op) {
if (returnsDynamicShape(op)) if (returnsDynamicShape(op)) {
op_worklist.insert(op);
});
// Iterate on the operations in the worklist until all operations have been
// inferred or no change happened (fix point).
while (!op_worklist.empty()) {
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
auto nextop = llvm::find_if(op_worklist, returnsDynamicShape);
if (nextop == op_worklist.end())
break;
Operation *op = *nextop;
op_worklist.erase(op);
// Ask the operation to infer its output shapes.
if (auto shape_op = dyn_cast<ShapeInference>(op)) { if (auto shape_op = dyn_cast<ShapeInference>(op)) {
shape_op.inferShapes(); shape_op.inferShapes();
} else { } else {
@ -63,11 +46,18 @@ public:
return signalPassFailure(); return signalPassFailure();
} }
} }
});
// If the operation worklist isn't empty, this indicates a failure. int64_t dynamicOperations = 0;
if (!op_worklist.empty()) { f.walk([&](mlir::Operation *op) {
if (returnsDynamicShape(op))
dynamicOperations++;
});
// If any dynamic operations remain, this indicates a failure.
if (dynamicOperations != 0) {
f.emitError("Shape inference failed, ") f.emitError("Shape inference failed, ")
<< op_worklist.size() << " operations couldn't be inferred\n"; << dynamicOperations << " operations couldn't be inferred\n";
signalPassFailure(); signalPassFailure();
} }