Enable inference for arbitrary number of instructions (#12)

* Fix shape inference.

* Remove comment.

* Remove worklist since it is not needed.
This commit is contained in:
Gheorghe-Teodor Bercea 2020-03-10 14:16:03 -04:00 committed by GitHub
parent 1882059ac9
commit ba02b90e0b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 19 additions and 29 deletions

View File

@ -34,40 +34,30 @@ public:
void runOnFunction() override {
auto f = getFunction();
// Populate the worklist with the operations that need shape inference:
// these are operations that return a dynamic shape.
llvm::SmallPtrSet<mlir::Operation *, 16> op_worklist;
// Iterate on the operations that need shape inference i.e the operations
// that return a dynamic shape.
f.walk([&](mlir::Operation *op) {
if (returnsDynamicShape(op))
op_worklist.insert(op);
if (returnsDynamicShape(op)) {
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
shape_op.inferShapes();
} else {
op->emitError("unable to infer shape of operation without shape "
"inference interface");
return signalPassFailure();
}
}
});
// Iterate on the operations in the worklist until all operations have been
// inferred or no change happened (fix point).
while (!op_worklist.empty()) {
// Find the next operation ready for inference, that is an operation
// with all operands already resolved (non-generic).
auto nextop = llvm::find_if(op_worklist, returnsDynamicShape);
if (nextop == op_worklist.end())
break;
int64_t dynamicOperations = 0;
f.walk([&](mlir::Operation *op) {
if (returnsDynamicShape(op))
dynamicOperations++;
});
Operation *op = *nextop;
op_worklist.erase(op);
// Ask the operation to infer its output shapes.
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
shape_op.inferShapes();
} else {
op->emitError("unable to infer shape of operation without shape "
"inference interface");
return signalPassFailure();
}
}
// If the operation worklist isn't empty, this indicates a failure.
if (!op_worklist.empty()) {
// If any dynamic operations remain, this indicates a failure.
if (dynamicOperations != 0) {
f.emitError("Shape inference failed, ")
<< op_worklist.size() << " operations couldn't be inferred\n";
<< dynamicOperations << " operations couldn't be inferred\n";
signalPassFailure();
}