Extend fusion root heuristic to also work in partially bufferized programs.

We now follow data flow through tensor_cast, tensor_load and tensor_to_memref
operations.
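
A minimal sketch of the hop this adds, factored into a helper for
illustration (the helper name followTensorRoundTrip is hypothetical; the op
accessors are the ones the diff below uses):

    // Given the operation defining a fusion-root candidate, return the value
    // it aliases, or a null Value if it is none of the tensor round-trip ops.
    static Value followTensorRoundTrip(Operation *definingOp) {
      if (auto load = dyn_cast<TensorLoadOp>(definingOp))
        return load.memref();       // tensor_load reads from this buffer.
      if (auto to_memref = dyn_cast<TensorToMemrefOp>(definingOp))
        return to_memref.tensor();  // tensor_to_memref wraps this tensor.
      if (auto cast = dyn_cast<TensorCastOp>(definingOp))
        return cast.source();       // tensor_cast only changes the type.
      return Value();
    }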

PiperOrigin-RevId: 342851104
Stephan Herhut 2020-11-17 06:34:11 -08:00 committed by TensorFlow MLIR Team
parent d9113e0b4d
commit 0c7152e65c
2 changed files with 80 additions and 0 deletions


@@ -91,6 +91,31 @@ class LhloFuseLinalgPass
      if (result_buffers.insert(alias).second) {
        worklist.push_back(alias);
      }
      continue;
    }
    // Follow the data flow through ops that bridge the tensor and memref
    // worlds, so the heuristic also works in partially bufferized programs.
    if (auto tensor_load = dyn_cast<TensorLoadOp>(definingOp)) {
      auto alias = tensor_load.memref();
      if (result_buffers.insert(alias).second) {
        worklist.push_back(alias);
      }
      continue;
    }
    if (auto tensor_to_memref = dyn_cast<TensorToMemrefOp>(definingOp)) {
      auto alias = tensor_to_memref.tensor();
      if (result_buffers.insert(alias).second) {
        worklist.push_back(alias);
      }
      continue;
    }
    if (auto tensor_cast = dyn_cast<TensorCastOp>(definingOp)) {
      auto alias = tensor_cast.source();
      if (result_buffers.insert(alias).second) {
        worklist.push_back(alias);
      }
      continue;
    }
    if (auto regionInterface =
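
For context, these checks run inside a worklist that iterates to a fixed
point over the aliases of the result buffers. A condensed sketch, reusing the
hypothetical helper from above (the real loop also handles view-like ops and
region interfaces, as the hunk shows):

    SmallVector<Value, 8> worklist(result_buffers.begin(), result_buffers.end());
    while (!worklist.empty()) {
      Value result = worklist.pop_back_val();
      Operation *definingOp = result.getDefiningOp();
      if (!definingOp) continue;
      // Every newly discovered alias goes back on the worklist, so chains of
      // tensor_load / tensor_cast / tensor_to_memref are followed transitively.
      if (Value alias = followTensorRoundTrip(definingOp))
        if (result_buffers.insert(alias).second)
          worklist.push_back(alias);
    }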


@@ -372,3 +372,58 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index
// PLOOP: else
// PLOOP: memref_reshape
// PLOOP: scf.yield
// -----
// Confirm that tiling information is passed through tensor_load, tensor_cast
// and tensor_to_memref operations.
func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
    -> memref<?xf32> {
  %c1 = constant 1 : index
  %1 = alloc() : memref<32xf32>
  linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                   affine_map<(d0) -> (d0)>],
                  iterator_types = ["parallel"]}
      ins(%arg0 : memref<32xf32>) outs(%1 : memref<32xf32>) {
  ^bb0(%arg3: f32, %arg4: f32):  // no predecessors
    %13 = absf %arg3 : f32
    linalg.yield %13 : f32
  }
  %2 = tensor_load %1 : memref<32xf32>
  %3 = tensor_cast %2 : tensor<32xf32> to tensor<?xf32>
  %4 = tensor_to_memref %3 : memref<?xf32>
  return %4 : memref<?xf32>
}
// CHECK-LABEL: func @tensor_ops
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: tensor_load
// CHECK: tensor_cast
// CHECK: tensor_to_memref
// TILED-LABEL: func @tensor_ops
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: tensor_load
// TILED: tensor_cast
// TILED: tensor_to_memref
// PLOOP-LABEL: func @tensor_ops
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: tensor_load
// PLOOP: tensor_cast
// PLOOP: tensor_to_memref
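
The three prefixes exercise three configurations of the same pass: CHECK is
the default run (tiling with step 1, hence the reuse of the constant 1),
TILED is a run with an explicit tile size of 2, and PLOOP is a run that tiles
to scf.parallel loops instead of scf.for. The RUN lines that select these
configurations live at the top of the test file and are not part of this
hunk.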