[MLIR] Move some walk()ing functions to the lmhlo::FusionOp API.

PiperOrigin-RevId: 344109366
This commit is contained in:
Tim Shen 2020-11-24 12:21:49 -08:00 committed by TensorFlow MLIR Team
parent 85f92a1651
commit 3e01448481
2 changed files with 39 additions and 0 deletions

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "llvm/ADT/StringRef.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Location.h"

View File

@ -656,6 +656,44 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
let builders = [
OpBuilderDAG<(ins "ArrayRef<NamedAttribute>":$attributes)>
];
let extraClassDeclaration = [{
SmallVector<Value, 4> getInputBuffers() {
SmallVector<Value, 4> buffers;
this->region().walk([&](TensorLoadOp load) {
if (load.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(load.memref());
});
return buffers;
}
SmallVector<Value, 4> getOutputBuffers() {
SmallVector<Value, 4> buffers;
this->region().walk([&](TensorStoreOp store) {
if (store.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(store.memref());
});
return buffers;
}
SmallVector<Value, 4> getFusionParameters() {
SmallVector<Value, 4> buffers;
this->region().walk([&](TensorLoadOp load) {
if (load.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(load);
});
return buffers;
}
SmallVector<Value, 4> getFusionResults() {
SmallVector<Value, 4> buffers;
this->region().walk([&](TensorStoreOp store) {
if (store.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(store.tensor());
});
return buffers;
}
}];
}
def TerminatorOp :