Extend fusion root heuristic to also work in partially bufferized programs.
We now follow data flow though tensor_cast, tensor_load and tensor_to_memref operations. PiperOrigin-RevId: 342851104
This commit is contained in:
parent
d9113e0b4d
commit
0c7152e65c
|
@ -91,6 +91,31 @@ class LhloFuseLinalgPass
|
|||
if (result_buffers.insert(alias).second) {
|
||||
worklist.push_back(alias);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto tensor_load = dyn_cast<TensorLoadOp>(definingOp)) {
|
||||
auto alias = tensor_load.memref();
|
||||
if (result_buffers.insert(alias).second) {
|
||||
worklist.push_back(alias);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto tensor_to_memref = dyn_cast<TensorToMemrefOp>(definingOp)) {
|
||||
auto alias = tensor_to_memref.tensor();
|
||||
if (result_buffers.insert(alias).second) {
|
||||
worklist.push_back(alias);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto tensor_cast = dyn_cast<TensorCastOp>(definingOp)) {
|
||||
auto alias = tensor_cast.source();
|
||||
if (result_buffers.insert(alias).second) {
|
||||
worklist.push_back(alias);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto regionInterface =
|
||||
|
|
|
@ -372,3 +372,58 @@ func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: inde
|
|||
// PLOOP: else
|
||||
// PLOOP: memref_reshape
|
||||
// PLOOP: scf.yield
|
||||
|
||||
// -----
|
||||
|
||||
// Confirm that tiling information is passed through tensor_load, tensor_cast
|
||||
// and memref_to_tensor operations.
|
||||
func @tensor_ops(%arg0: memref<32xf32>, %arg1: memref<32xindex>)
|
||||
-> memref<?xf32> {
|
||||
%c1 = constant 1 : index
|
||||
%1 = alloc() : memref<32xf32>
|
||||
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
||||
affine_map<(d0) -> (d0)>],
|
||||
iterator_types = ["parallel"]}
|
||||
ins(%arg0 : memref<32xf32>) outs(%1 : memref<32xf32>) {
|
||||
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
||||
%13 = absf %arg3 : f32
|
||||
linalg.yield %13 : f32
|
||||
}
|
||||
%2 = tensor_load %1 : memref<32xf32>
|
||||
%3 = tensor_cast %2 : tensor<32xf32> to tensor<?xf32>
|
||||
%4 = tensor_to_memref %3 : memref<?xf32>
|
||||
return %4 : memref<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tensor_ops
|
||||
// CHECK: %[[C1:.*]] = constant 1
|
||||
// CHECK-NOT: linalg.generic
|
||||
// CHECK: scf.for {{.*}} step %[[C1]]
|
||||
// CHECK-NOT: scf.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: absf
|
||||
// CHECK: tensor_load
|
||||
// CHECK: tensor_cast
|
||||
// CHECK: tensor_to_memref
|
||||
|
||||
// TILED-LABEL: func @tensor_ops
|
||||
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||
// TILED-NOT: linalg.generic
|
||||
// TILED: scf.for {{.*}} step %[[C2]]
|
||||
// TILED-NOT: scf.for
|
||||
// TILED: linalg.generic
|
||||
// TILED: absf
|
||||
// TILED: tensor_load
|
||||
// TILED: tensor_cast
|
||||
// TILED: tensor_to_memref
|
||||
|
||||
|
||||
// PLOOP-LABEL: func @tensor_ops
|
||||
// PLOOP-NOT: linalg.generic
|
||||
// PLOOP: scf.parallel
|
||||
// PLOOP-NOT: scf.parallel
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: absf
|
||||
// PLOOP: tensor_load
|
||||
// PLOOP: tensor_cast
|
||||
// PLOOP: tensor_to_memref
|
||||
|
|
Loading…
Reference in New Issue