[MLIR] Move some walk()ing functions to the lmhlo::FusionOp API.
PiperOrigin-RevId: 344109366
This commit is contained in:
		
							parent
							
								
									85f92a1651
								
							
						
					
					
						commit
						3e01448481
					
				| 
						 | 
				
			
			@ -20,6 +20,7 @@ limitations under the License.
 | 
			
		|||
 | 
			
		||||
#include "llvm/ADT/StringRef.h"
 | 
			
		||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
 | 
			
		||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
 | 
			
		||||
#include "mlir/IR/Attributes.h"
 | 
			
		||||
#include "mlir/IR/Dialect.h"
 | 
			
		||||
#include "mlir/IR/Location.h"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -656,6 +656,44 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
 | 
			
		|||
  let builders = [
 | 
			
		||||
     OpBuilderDAG<(ins "ArrayRef<NamedAttribute>":$attributes)>
 | 
			
		||||
  ];
 | 
			
		||||
 | 
			
		||||
  let extraClassDeclaration = [{
 | 
			
		||||
    SmallVector<Value, 4> getInputBuffers() {
 | 
			
		||||
      SmallVector<Value, 4> buffers;
 | 
			
		||||
      this->region().walk([&](TensorLoadOp load) {
 | 
			
		||||
        if (load.memref().getParentRegion()->isProperAncestor(®ion()))
 | 
			
		||||
          buffers.push_back(load.memref());
 | 
			
		||||
      });
 | 
			
		||||
      return buffers;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    SmallVector<Value, 4> getOutputBuffers() {
 | 
			
		||||
      SmallVector<Value, 4> buffers;
 | 
			
		||||
      this->region().walk([&](TensorStoreOp store) {
 | 
			
		||||
        if (store.memref().getParentRegion()->isProperAncestor(®ion()))
 | 
			
		||||
          buffers.push_back(store.memref());
 | 
			
		||||
      });
 | 
			
		||||
      return buffers;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    SmallVector<Value, 4> getFusionParameters() {
 | 
			
		||||
      SmallVector<Value, 4> buffers;
 | 
			
		||||
      this->region().walk([&](TensorLoadOp load) {
 | 
			
		||||
        if (load.memref().getParentRegion()->isProperAncestor(®ion()))
 | 
			
		||||
          buffers.push_back(load);
 | 
			
		||||
      });
 | 
			
		||||
      return buffers;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    SmallVector<Value, 4> getFusionResults() {
 | 
			
		||||
      SmallVector<Value, 4> buffers;
 | 
			
		||||
      this->region().walk([&](TensorStoreOp store) {
 | 
			
		||||
        if (store.memref().getParentRegion()->isProperAncestor(®ion()))
 | 
			
		||||
          buffers.push_back(store.tensor());
 | 
			
		||||
      });
 | 
			
		||||
      return buffers;
 | 
			
		||||
    }
 | 
			
		||||
  }];
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
def TerminatorOp :
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue