diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index 9dc6d7a..78e9c7e 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index ff052b5..3d5a4a3 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -656,6 +656,44 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">] let builders = [ OpBuilderDAG<(ins "ArrayRef":$attributes)> ]; + + let extraClassDeclaration = [{ + SmallVector getInputBuffers() { + SmallVector buffers; + this->region().walk([&](TensorLoadOp load) { + if (load.memref().getParentRegion()->isProperAncestor(®ion())) + buffers.push_back(load.memref()); + }); + return buffers; + } + + SmallVector getOutputBuffers() { + SmallVector buffers; + this->region().walk([&](TensorStoreOp store) { + if (store.memref().getParentRegion()->isProperAncestor(®ion())) + buffers.push_back(store.memref()); + }); + return buffers; + } + + SmallVector getFusionParameters() { + SmallVector buffers; + this->region().walk([&](TensorLoadOp load) { + if (load.memref().getParentRegion()->isProperAncestor(®ion())) + buffers.push_back(load); + }); + return buffers; + } + + SmallVector getFusionResults() { + SmallVector buffers; + this->region().walk([&](TensorStoreOp store) { + if (store.memref().getParentRegion()->isProperAncestor(®ion())) + buffers.push_back(store.tensor()); + }); + return buffers; + } + }]; } def TerminatorOp :