Enable inference for arbitrary number of instructions (#12)
* Fix shape inference. * Remove comment. * Remove worklist since it is not needed.
This commit is contained in:
parent
1882059ac9
commit
ba02b90e0b
|
@ -34,40 +34,30 @@ public:
|
||||||
void runOnFunction() override {
|
void runOnFunction() override {
|
||||||
auto f = getFunction();
|
auto f = getFunction();
|
||||||
|
|
||||||
// Populate the worklist with the operations that need shape inference:
|
// Iterate on the operations that need shape inference i.e the operations
|
||||||
// these are operations that return a dynamic shape.
|
// that return a dynamic shape.
|
||||||
llvm::SmallPtrSet<mlir::Operation *, 16> op_worklist;
|
|
||||||
f.walk([&](mlir::Operation *op) {
|
f.walk([&](mlir::Operation *op) {
|
||||||
if (returnsDynamicShape(op))
|
if (returnsDynamicShape(op)) {
|
||||||
op_worklist.insert(op);
|
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
||||||
|
shape_op.inferShapes();
|
||||||
|
} else {
|
||||||
|
op->emitError("unable to infer shape of operation without shape "
|
||||||
|
"inference interface");
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// Iterate on the operations in the worklist until all operations have been
|
int64_t dynamicOperations = 0;
|
||||||
// inferred or no change happened (fix point).
|
f.walk([&](mlir::Operation *op) {
|
||||||
while (!op_worklist.empty()) {
|
if (returnsDynamicShape(op))
|
||||||
// Find the next operation ready for inference, that is an operation
|
dynamicOperations++;
|
||||||
// with all operands already resolved (non-generic).
|
});
|
||||||
auto nextop = llvm::find_if(op_worklist, returnsDynamicShape);
|
|
||||||
if (nextop == op_worklist.end())
|
|
||||||
break;
|
|
||||||
|
|
||||||
Operation *op = *nextop;
|
// If any dynamic operations remain, this indicates a failure.
|
||||||
op_worklist.erase(op);
|
if (dynamicOperations != 0) {
|
||||||
|
|
||||||
// Ask the operation to infer its output shapes.
|
|
||||||
if (auto shape_op = dyn_cast<ShapeInference>(op)) {
|
|
||||||
shape_op.inferShapes();
|
|
||||||
} else {
|
|
||||||
op->emitError("unable to infer shape of operation without shape "
|
|
||||||
"inference interface");
|
|
||||||
return signalPassFailure();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the operation worklist isn't empty, this indicates a failure.
|
|
||||||
if (!op_worklist.empty()) {
|
|
||||||
f.emitError("Shape inference failed, ")
|
f.emitError("Shape inference failed, ")
|
||||||
<< op_worklist.size() << " operations couldn't be inferred\n";
|
<< dynamicOperations << " operations couldn't be inferred\n";
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue