diff --git a/src/pass/shape_inference_pass.cpp b/src/pass/shape_inference_pass.cpp index 47826af..59bee64 100644 --- a/src/pass/shape_inference_pass.cpp +++ b/src/pass/shape_inference_pass.cpp @@ -34,40 +34,30 @@ public: void runOnFunction() override { auto f = getFunction(); - // Populate the worklist with the operations that need shape inference: - // these are operations that return a dynamic shape. - llvm::SmallPtrSet op_worklist; + // Iterate on the operations that need shape inference i.e the operations + // that return a dynamic shape. f.walk([&](mlir::Operation *op) { - if (returnsDynamicShape(op)) - op_worklist.insert(op); + if (returnsDynamicShape(op)) { + if (auto shape_op = dyn_cast(op)) { + shape_op.inferShapes(); + } else { + op->emitError("unable to infer shape of operation without shape " + "inference interface"); + return signalPassFailure(); + } + } }); - // Iterate on the operations in the worklist until all operations have been - // inferred or no change happened (fix point). - while (!op_worklist.empty()) { - // Find the next operation ready for inference, that is an operation - // with all operands already resolved (non-generic). - auto nextop = llvm::find_if(op_worklist, returnsDynamicShape); - if (nextop == op_worklist.end()) - break; + int64_t dynamicOperations = 0; + f.walk([&](mlir::Operation *op) { + if (returnsDynamicShape(op)) + dynamicOperations++; + }); - Operation *op = *nextop; - op_worklist.erase(op); - - // Ask the operation to infer its output shapes. - if (auto shape_op = dyn_cast(op)) { - shape_op.inferShapes(); - } else { - op->emitError("unable to infer shape of operation without shape " - "inference interface"); - return signalPassFailure(); - } - } - - // If the operation worklist isn't empty, this indicates a failure. - if (!op_worklist.empty()) { + // If any dynamic operations remain, this indicates a failure. + if (dynamicOperations != 0) { f.emitError("Shape inference failed, ") - << op_worklist.size() << " operations couldn't be inferred\n"; + << dynamicOperations << " operations couldn't be inferred\n"; signalPassFailure(); }