Enable inference for arbitrary number of instructions (#12)
* Fix shape inference. * Remove comment. * Remove worklist since it is not needed.
This commit is contained in:
		
							parent
							
								
									1882059ac9
								
							
						
					
					
						commit
						ba02b90e0b
					
				|  | @ -34,40 +34,30 @@ public: | |||
|   void runOnFunction() override { | ||||
|     auto f = getFunction(); | ||||
| 
 | ||||
|     // Populate the worklist with the operations that need shape inference:
 | ||||
|     // these are operations that return a dynamic shape.
 | ||||
|     llvm::SmallPtrSet<mlir::Operation *, 16> op_worklist; | ||||
|     // Iterate on the operations that need shape inference i.e the operations
 | ||||
|     // that return a dynamic shape.
 | ||||
|     f.walk([&](mlir::Operation *op) { | ||||
|       if (returnsDynamicShape(op)) | ||||
|         op_worklist.insert(op); | ||||
|       if (returnsDynamicShape(op)) { | ||||
|         if (auto shape_op = dyn_cast<ShapeInference>(op)) { | ||||
|           shape_op.inferShapes(); | ||||
|         } else { | ||||
|           op->emitError("unable to infer shape of operation without shape " | ||||
|                         "inference interface"); | ||||
|           return signalPassFailure(); | ||||
|         } | ||||
|       } | ||||
|     }); | ||||
| 
 | ||||
|     // Iterate on the operations in the worklist until all operations have been
 | ||||
|     // inferred or no change happened (fix point).
 | ||||
|     while (!op_worklist.empty()) { | ||||
|       // Find the next operation ready for inference, that is an operation
 | ||||
|       // with all operands already resolved (non-generic).
 | ||||
|       auto nextop = llvm::find_if(op_worklist, returnsDynamicShape); | ||||
|       if (nextop == op_worklist.end()) | ||||
|         break; | ||||
|     int64_t dynamicOperations = 0; | ||||
|     f.walk([&](mlir::Operation *op) { | ||||
|       if (returnsDynamicShape(op)) | ||||
|         dynamicOperations++; | ||||
|     }); | ||||
| 
 | ||||
|       Operation *op = *nextop; | ||||
|       op_worklist.erase(op); | ||||
| 
 | ||||
|       // Ask the operation to infer its output shapes.
 | ||||
|       if (auto shape_op = dyn_cast<ShapeInference>(op)) { | ||||
|         shape_op.inferShapes(); | ||||
|       } else { | ||||
|         op->emitError("unable to infer shape of operation without shape " | ||||
|                       "inference interface"); | ||||
|         return signalPassFailure(); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     // If the operation worklist isn't empty, this indicates a failure.
 | ||||
|     if (!op_worklist.empty()) { | ||||
|     // If any dynamic operations remain, this indicates a failure.
 | ||||
|     if (dynamicOperations != 0) { | ||||
|       f.emitError("Shape inference failed, ") | ||||
|           << op_worklist.size() << " operations couldn't be inferred\n"; | ||||
|           << dynamicOperations << " operations couldn't be inferred\n"; | ||||
|       signalPassFailure(); | ||||
|     } | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue