[LMHLO] Simplify FusionOp::getInputBuffers() and friends.
- No need to walk the entire region, instead just iterate over the top level operations in the region attached to the fusion op. PiperOrigin-RevId: 366528833
This commit is contained in:
parent
38d0f96709
commit
0800423d27
|
@ -687,39 +687,35 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
|
|||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
SmallVector<Value, 4> getInputBuffers() {
|
||||
SmallVector<Value, 4> buffers;
|
||||
this->region().walk([&](memref::TensorLoadOp load) {
|
||||
if (load.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||
buffers.push_back(load.memref());
|
||||
});
|
||||
SmallVector<Value> getInputBuffers() {
|
||||
SmallVector<Value> buffers;
|
||||
for (auto load : region().front().getOps<memref::TensorLoadOp>()) {
|
||||
buffers.push_back(load.memref());
|
||||
}
|
||||
return buffers;
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> getOutputBuffers() {
|
||||
SmallVector<Value, 4> buffers;
|
||||
this->region().walk([&](memref::TensorStoreOp store) {
|
||||
if (store.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||
buffers.push_back(store.memref());
|
||||
});
|
||||
SmallVector<Value> getOutputBuffers() {
|
||||
SmallVector<Value> buffers;
|
||||
for (auto store : region().front().getOps<memref::TensorStoreOp>()) {
|
||||
buffers.push_back(store.memref());
|
||||
}
|
||||
return buffers;
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> getFusionParameters() {
|
||||
SmallVector<Value, 4> buffers;
|
||||
this->region().walk([&](memref::TensorLoadOp load) {
|
||||
if (load.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||
buffers.push_back(load);
|
||||
});
|
||||
SmallVector<Value> getFusionParameters() {
|
||||
SmallVector<Value> buffers;
|
||||
for (auto load : region().front().getOps<memref::TensorLoadOp>()) {
|
||||
buffers.push_back(load);
|
||||
}
|
||||
return buffers;
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> getFusionResults() {
|
||||
SmallVector<Value, 4> buffers;
|
||||
this->region().walk([&](memref::TensorStoreOp store) {
|
||||
if (store.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||
buffers.push_back(store.tensor());
|
||||
});
|
||||
SmallVector<Value> getFusionResults() {
|
||||
SmallVector<Value> buffers;
|
||||
for (auto store : region().front().getOps<memref::TensorStoreOp>()) {
|
||||
buffers.push_back(store.tensor());
|
||||
}
|
||||
return buffers;
|
||||
}
|
||||
}];
|
||||
|
|
Loading…
Reference in New Issue