[LMHLO] Simplify FusionOp::getInputBuffers() and friends.

- No need to walk the entire region, instead just iterate over the top level operations in
  the region attached to the fusion op.

PiperOrigin-RevId: 366528833
This commit is contained in:
Rahul Joshi 2021-04-02 15:54:45 -07:00 committed by TensorFlow MLIR Team
parent 38d0f96709
commit 0800423d27
1 changed files with 20 additions and 24 deletions

View File

@ -687,39 +687,35 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
];
let extraClassDeclaration = [{
SmallVector<Value, 4> getInputBuffers() {
SmallVector<Value, 4> buffers;
this->region().walk([&](memref::TensorLoadOp load) {
if (load.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(load.memref());
});
SmallVector<Value> getInputBuffers() {
SmallVector<Value> buffers;
for (auto load : region().front().getOps<memref::TensorLoadOp>()) {
buffers.push_back(load.memref());
}
return buffers;
}
SmallVector<Value, 4> getOutputBuffers() {
SmallVector<Value, 4> buffers;
this->region().walk([&](memref::TensorStoreOp store) {
if (store.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(store.memref());
});
SmallVector<Value> getOutputBuffers() {
SmallVector<Value> buffers;
for (auto store : region().front().getOps<memref::TensorStoreOp>()) {
buffers.push_back(store.memref());
}
return buffers;
}
SmallVector<Value, 4> getFusionParameters() {
SmallVector<Value, 4> buffers;
this->region().walk([&](memref::TensorLoadOp load) {
if (load.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(load);
});
SmallVector<Value> getFusionParameters() {
SmallVector<Value> buffers;
for (auto load : region().front().getOps<memref::TensorLoadOp>()) {
buffers.push_back(load);
}
return buffers;
}
SmallVector<Value, 4> getFusionResults() {
SmallVector<Value, 4> buffers;
this->region().walk([&](memref::TensorStoreOp store) {
if (store.memref().getParentRegion()->isProperAncestor(&region()))
buffers.push_back(store.tensor());
});
SmallVector<Value> getFusionResults() {
SmallVector<Value> buffers;
for (auto store : region().front().getOps<memref::TensorStoreOp>()) {
buffers.push_back(store.tensor());
}
return buffers;
}
}];