From 3e0144848112c57ebfaa83258b1472765d031b5b Mon Sep 17 00:00:00 2001 From: Tim Shen Date: Tue, 24 Nov 2020 12:21:49 -0800 Subject: [PATCH] [MLIR] Move some walk()ing functions to the lmhlo::FusionOp API. PiperOrigin-RevId: 344109366 --- include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h | 1 + include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td | 38 ++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index 9dc6d7a..78e9c7e 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index ff052b5..3d5a4a3 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -656,6 +656,44 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">] let builders = [ OpBuilderDAG<(ins "ArrayRef":$attributes)> ]; + + let extraClassDeclaration = [{ + SmallVector getInputBuffers() { + SmallVector buffers; + this->region().walk([&](TensorLoadOp load) { + if (load.memref().getParentRegion()->isProperAncestor(®ion())) + buffers.push_back(load.memref()); + }); + return buffers; + } + + SmallVector getOutputBuffers() { + SmallVector buffers; + this->region().walk([&](TensorStoreOp store) { + if (store.memref().getParentRegion()->isProperAncestor(®ion())) + buffers.push_back(store.memref()); + }); + return buffers; + } + + SmallVector getFusionParameters() { + SmallVector buffers; + this->region().walk([&](TensorLoadOp load) { + if (load.memref().getParentRegion()->isProperAncestor(®ion())) + buffers.push_back(load); + }); + return buffers; + } + + SmallVector getFusionResults() { + SmallVector buffers; + this->region().walk([&](TensorStoreOp store) { + if (store.memref().getParentRegion()->isProperAncestor(®ion())) + buffers.push_back(store.tensor()); + }); + return buffers; + } + }]; } def TerminatorOp :