Enable inference for arbitrary number of instructions (#12)
* Fix shape inference. * Remove comment. * Remove worklist since it is not needed.
This commit is contained in:
parent
1882059ac9
commit
ba02b90e0b
|
@ -34,27 +34,10 @@ public:
|
|||
void runOnFunction() override {
|
||||
auto f = getFunction();
|
||||
|
||||
// Populate the worklist with the operations that need shape inference:
|
||||
// these are operations that return a dynamic shape.
|
||||
llvm::SmallPtrSet<mlir::Operation *, 16> op_worklist;
|
||||
// Iterate on the operations that need shape inference i.e the operations
|
||||
// that return a dynamic shape.
|
||||
f.walk([&](mlir::Operation *op) {
|
||||
if (returnsDynamicShape(op))
|
||||
op_worklist.insert(op);
|
||||
});
|
||||
|
||||
// Iterate on the operations in the worklist until all operations have been
|
||||
// inferred or no change happened (fix point).
|
||||
while (!op_worklist.empty()) {
|
||||
// Find the next operation ready for inference, that is an operation
|
||||
// with all operands already resolved (non-generic).
|
||||
auto nextop = llvm::find_if(op_worklist, returnsDynamicShape);
|
||||
if (nextop == op_worklist.end())
|
||||
break;
|
||||
|
||||
Operation *op = *nextop;
|
||||
op_worklist.erase(op);
|
||||
|
||||
// Ask the operation to infer its output shapes.
|
||||
if (returnsDynamicShape(op)) {
|
||||
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
||||
shape_op.inferShapes();
|
||||
} else {
|
||||
|
@ -63,11 +46,18 @@ public:
|
|||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// If the operation worklist isn't empty, this indicates a failure.
|
||||
if (!op_worklist.empty()) {
|
||||
int64_t dynamicOperations = 0;
|
||||
f.walk([&](mlir::Operation *op) {
|
||||
if (returnsDynamicShape(op))
|
||||
dynamicOperations++;
|
||||
});
|
||||
|
||||
// If any dynamic operations remain, this indicates a failure.
|
||||
if (dynamicOperations != 0) {
|
||||
f.emitError("Shape inference failed, ")
|
||||
<< op_worklist.size() << " operations couldn't be inferred\n";
|
||||
<< dynamicOperations << " operations couldn't be inferred\n";
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue