[MLIR] Move some walk()ing functions to the lmhlo::FusionOp API.
PiperOrigin-RevId: 344109366
This commit is contained in:
		
							parent
							
								
									85f92a1651
								
							
						
					
					
						commit
						3e01448481
					
				| 
						 | 
					@ -20,6 +20,7 @@ limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#include "llvm/ADT/StringRef.h"
 | 
					#include "llvm/ADT/StringRef.h"
 | 
				
			||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
 | 
					#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
 | 
				
			||||||
 | 
					#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
				
			||||||
#include "mlir/IR/Attributes.h"
 | 
					#include "mlir/IR/Attributes.h"
 | 
				
			||||||
#include "mlir/IR/Dialect.h"
 | 
					#include "mlir/IR/Dialect.h"
 | 
				
			||||||
#include "mlir/IR/Location.h"
 | 
					#include "mlir/IR/Location.h"
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -656,6 +656,44 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
 | 
				
			||||||
  let builders = [
 | 
					  let builders = [
 | 
				
			||||||
     OpBuilderDAG<(ins "ArrayRef<NamedAttribute>":$attributes)>
 | 
					     OpBuilderDAG<(ins "ArrayRef<NamedAttribute>":$attributes)>
 | 
				
			||||||
  ];
 | 
					  ];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let extraClassDeclaration = [{
 | 
				
			||||||
 | 
					    SmallVector<Value, 4> getInputBuffers() {
 | 
				
			||||||
 | 
					      SmallVector<Value, 4> buffers;
 | 
				
			||||||
 | 
					      this->region().walk([&](TensorLoadOp load) {
 | 
				
			||||||
 | 
					        if (load.memref().getParentRegion()->isProperAncestor(®ion()))
 | 
				
			||||||
 | 
					          buffers.push_back(load.memref());
 | 
				
			||||||
 | 
					      });
 | 
				
			||||||
 | 
					      return buffers;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    SmallVector<Value, 4> getOutputBuffers() {
 | 
				
			||||||
 | 
					      SmallVector<Value, 4> buffers;
 | 
				
			||||||
 | 
					      this->region().walk([&](TensorStoreOp store) {
 | 
				
			||||||
 | 
					        if (store.memref().getParentRegion()->isProperAncestor(®ion()))
 | 
				
			||||||
 | 
					          buffers.push_back(store.memref());
 | 
				
			||||||
 | 
					      });
 | 
				
			||||||
 | 
					      return buffers;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    SmallVector<Value, 4> getFusionParameters() {
 | 
				
			||||||
 | 
					      SmallVector<Value, 4> buffers;
 | 
				
			||||||
 | 
					      this->region().walk([&](TensorLoadOp load) {
 | 
				
			||||||
 | 
					        if (load.memref().getParentRegion()->isProperAncestor(®ion()))
 | 
				
			||||||
 | 
					          buffers.push_back(load);
 | 
				
			||||||
 | 
					      });
 | 
				
			||||||
 | 
					      return buffers;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    SmallVector<Value, 4> getFusionResults() {
 | 
				
			||||||
 | 
					      SmallVector<Value, 4> buffers;
 | 
				
			||||||
 | 
					      this->region().walk([&](TensorStoreOp store) {
 | 
				
			||||||
 | 
					        if (store.memref().getParentRegion()->isProperAncestor(®ion()))
 | 
				
			||||||
 | 
					          buffers.push_back(store.tensor());
 | 
				
			||||||
 | 
					      });
 | 
				
			||||||
 | 
					      return buffers;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }];
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def TerminatorOp :
 | 
					def TerminatorOp :
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue