[MLIR] Move some walk()ing functions to the lmhlo::FusionOp API.
PiperOrigin-RevId: 344109366
This commit is contained in:
parent
85f92a1651
commit
3e01448481
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_structs.h"
|
||||||
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
#include "mlir/IR/Location.h"
|
#include "mlir/IR/Location.h"
|
||||||
|
|
|
@ -656,6 +656,44 @@ def FusionOp : LHLO_Op<"fusion", [SingleBlockImplicitTerminator<"TerminatorOp">]
|
||||||
let builders = [
|
let builders = [
|
||||||
OpBuilderDAG<(ins "ArrayRef<NamedAttribute>":$attributes)>
|
OpBuilderDAG<(ins "ArrayRef<NamedAttribute>":$attributes)>
|
||||||
];
|
];
|
||||||
|
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
SmallVector<Value, 4> getInputBuffers() {
|
||||||
|
SmallVector<Value, 4> buffers;
|
||||||
|
this->region().walk([&](TensorLoadOp load) {
|
||||||
|
if (load.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||||
|
buffers.push_back(load.memref());
|
||||||
|
});
|
||||||
|
return buffers;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value, 4> getOutputBuffers() {
|
||||||
|
SmallVector<Value, 4> buffers;
|
||||||
|
this->region().walk([&](TensorStoreOp store) {
|
||||||
|
if (store.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||||
|
buffers.push_back(store.memref());
|
||||||
|
});
|
||||||
|
return buffers;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value, 4> getFusionParameters() {
|
||||||
|
SmallVector<Value, 4> buffers;
|
||||||
|
this->region().walk([&](TensorLoadOp load) {
|
||||||
|
if (load.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||||
|
buffers.push_back(load);
|
||||||
|
});
|
||||||
|
return buffers;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value, 4> getFusionResults() {
|
||||||
|
SmallVector<Value, 4> buffers;
|
||||||
|
this->region().walk([&](TensorStoreOp store) {
|
||||||
|
if (store.memref().getParentRegion()->isProperAncestor(®ion()))
|
||||||
|
buffers.push_back(store.tensor());
|
||||||
|
});
|
||||||
|
return buffers;
|
||||||
|
}
|
||||||
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def TerminatorOp :
|
def TerminatorOp :
|
||||||
|
|
Loading…
Reference in New Issue