[MLIR] Move some walk()ing functions to the lmhlo::FusionOp API.
PiperOrigin-RevId: 344109366
This commit is contained in:
parent
85f92a1651
commit
3e01448481
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
|
|
|
@ -656,6 +656,44 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
|
|||
let builders = [
|
||||
OpBuilderDAG<(ins "ArrayRef<NamedAttribute>":$attributes)>
|
||||
];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
SmallVector<Value, 4> getInputBuffers() {
|
||||
SmallVector<Value, 4> buffers;
|
||||
this->region().walk([&](TensorLoadOp load) {
|
||||
if (load.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||
buffers.push_back(load.memref());
|
||||
});
|
||||
return buffers;
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> getOutputBuffers() {
|
||||
SmallVector<Value, 4> buffers;
|
||||
this->region().walk([&](TensorStoreOp store) {
|
||||
if (store.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||
buffers.push_back(store.memref());
|
||||
});
|
||||
return buffers;
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> getFusionParameters() {
|
||||
SmallVector<Value, 4> buffers;
|
||||
this->region().walk([&](TensorLoadOp load) {
|
||||
if (load.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||
buffers.push_back(load);
|
||||
});
|
||||
return buffers;
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> getFusionResults() {
|
||||
SmallVector<Value, 4> buffers;
|
||||
this->region().walk([&](TensorStoreOp store) {
|
||||
if (store.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||
buffers.push_back(store.tensor());
|
||||
});
|
||||
return buffers;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def TerminatorOp :
|
||||
|
|
Loading…
Reference in New Issue