Update lhlo to use the new structured op interface.
Replace deprecated methods in lhlo_fuse_linalg.cc. The new structured op interface has been introduced in https://reviews.llvm.org/D103394. PiperOrigin-RevId: 377875452
This commit is contained in:
parent
ade873a5e0
commit
fc723380e6
|
@ -156,8 +156,8 @@ class LhloFuseLinalgPass
|
||||||
tile_sizes = SmallVector<int64_t, 2>(generic_op.getNumLoops(), 1);
|
tile_sizes = SmallVector<int64_t, 2>(generic_op.getNumLoops(), 1);
|
||||||
}
|
}
|
||||||
auto op = cast<LinalgOp>(generic_op.getOperation());
|
auto op = cast<LinalgOp>(generic_op.getOperation());
|
||||||
for (const Value result : op.getOutputBuffers()) {
|
for (OpOperand* op_operand : op.getOutputBufferOperands()) {
|
||||||
if (!result_buffers.count(result)) continue;
|
if (!result_buffers.count(op_operand->get())) continue;
|
||||||
if (tileGenericOp(op, tile_sizes, &b)) {
|
if (tileGenericOp(op, tile_sizes, &b)) {
|
||||||
generic_op.erase();
|
generic_op.erase();
|
||||||
return;
|
return;
|
||||||
|
@ -172,10 +172,10 @@ class LhloFuseLinalgPass
|
||||||
SmallVector<LinalgOp, 8> linalg_ops;
|
SmallVector<LinalgOp, 8> linalg_ops;
|
||||||
func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
|
func.walk([&](LinalgOp op) { linalg_ops.push_back(op); });
|
||||||
for (LinalgOp op : llvm::reverse(linalg_ops)) {
|
for (LinalgOp op : llvm::reverse(linalg_ops)) {
|
||||||
for (OpOperand& inputOperand : op.getInputOpOperands()) {
|
for (OpOperand* inputOperand : op.getInputOperands()) {
|
||||||
linalg::Aliases aliases;
|
linalg::Aliases aliases;
|
||||||
linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
|
linalg::LinalgDependenceGraph graph(aliases, linalg_ops);
|
||||||
if (auto info = fuseProducerOfBuffer(b, inputOperand, graph)) {
|
if (auto info = fuseProducerOfBuffer(b, *inputOperand, graph)) {
|
||||||
auto originalOp = info->originalProducer.getOperation();
|
auto originalOp = info->originalProducer.getOperation();
|
||||||
erase_set.insert(originalOp);
|
erase_set.insert(originalOp);
|
||||||
auto originalOpInLinalgOpsVector = std::find_if(
|
auto originalOpInLinalgOpsVector = std::find_if(
|
||||||
|
|
Loading…
Reference in New Issue