[LMHLO] Simplify FusionOp::getInputBuffers() and friends.

- No need to walk the entire region, instead just iterate over the top level operations in
  the region attached to the fusion op.

PiperOrigin-RevId: 366528833
This commit is contained in:
Rahul Joshi 2021-04-02 15:54:45 -07:00 committed by TensorFlow MLIR Team
parent 38d0f96709
commit 0800423d27
1 changed files with 20 additions and 24 deletions

View File

@ -687,39 +687,35 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
]; ];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
SmallVector<Value, 4> getInputBuffers() { SmallVector<Value> getInputBuffers() {
SmallVector<Value, 4> buffers; SmallVector<Value> buffers;
this->region().walk([&](memref::TensorLoadOp load) { for (auto load : region().front().getOps<memref::TensorLoadOp>()) {
if (load.memref().getParentRegion()->isProperAncestor(&region())) buffers.push_back(load.memref());
buffers.push_back(load.memref()); }
});
return buffers; return buffers;
} }
SmallVector<Value, 4> getOutputBuffers() { SmallVector<Value> getOutputBuffers() {
SmallVector<Value, 4> buffers; SmallVector<Value> buffers;
this->region().walk([&](memref::TensorStoreOp store) { for (auto store : region().front().getOps<memref::TensorStoreOp>()) {
if (store.memref().getParentRegion()->isProperAncestor(&region())) buffers.push_back(store.memref());
buffers.push_back(store.memref()); }
});
return buffers; return buffers;
} }
SmallVector<Value, 4> getFusionParameters() { SmallVector<Value> getFusionParameters() {
SmallVector<Value, 4> buffers; SmallVector<Value> buffers;
this->region().walk([&](memref::TensorLoadOp load) { for (auto load : region().front().getOps<memref::TensorLoadOp>()) {
if (load.memref().getParentRegion()->isProperAncestor(&region())) buffers.push_back(load);
buffers.push_back(load); }
});
return buffers; return buffers;
} }
SmallVector<Value, 4> getFusionResults() { SmallVector<Value> getFusionResults() {
SmallVector<Value, 4> buffers; SmallVector<Value> buffers;
this->region().walk([&](memref::TensorStoreOp store) { for (auto store : region().front().getOps<memref::TensorStoreOp>()) {
if (store.memref().getParentRegion()->isProperAncestor(&region())) buffers.push_back(store.tensor());
buffers.push_back(store.tensor()); }
});
return buffers; return buffers;
} }
}]; }];