Extend hlo-fuse-linalg slightly to support aliased returns.
The fusion heuristic identifies the root of a fusion by checking whether an output of a linalg operation is a function result. It did not consider outputs flowing through aliasing operations (like casts). PiperOrigin-RevId: 337479910
This commit is contained in:
parent
6a91ac6d8f
commit
2e30b59ddc
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/SCF/SCF.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
|
@ -73,6 +74,24 @@ class LhloFuseLinalgPass
|
|||
result_buffers.insert(operand);
|
||||
}
|
||||
}
|
||||
// Resolve aliasing operations (like casts) on the result to identify
|
||||
// results. This only handles escaping results.
|
||||
// TODO(herhut): Use BufferizeAliasAnalysis for this.
|
||||
llvm::SmallVector<Value, 4> worklist(result_buffers.begin(),
|
||||
result_buffers.end());
|
||||
while (!worklist.empty()) {
|
||||
Value result = worklist.pop_back_val();
|
||||
auto definingOp = result.getDefiningOp();
|
||||
if (!definingOp) {
|
||||
continue;
|
||||
}
|
||||
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
|
||||
auto alias = viewLike.getViewSource();
|
||||
if (result_buffers.insert(alias).second) {
|
||||
worklist.push_back(alias);
|
||||
}
|
||||
}
|
||||
}
|
||||
MLIRContext* ctx = func.getContext();
|
||||
OpBuilder b(func);
|
||||
OperationFolder folder(ctx);
|
||||
|
|
|
@ -3,7 +3,8 @@
|
|||
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
||||
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
|
||||
%temp_result = alloc() : memref<6x6xf32>
|
||||
|
@ -73,7 +74,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
|||
}
|
||||
%1 = alloc() : memref<100x10xf32>
|
||||
linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||
affine_map<(d0, d1) -> (d0, d1)>,
|
||||
affine_map<(d0, d1) -> (d0, d1)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>)
|
||||
outs(%1 : memref<100x10xf32>) {
|
||||
|
@ -83,7 +86,8 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
|||
}
|
||||
dealloc %0 : memref<100x10xf32>
|
||||
linalg.generic {
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
|
||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||
affine_map<(d0, d1) -> (d0, d1)>],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
ins(%1 : memref<100x10xf32>)
|
||||
outs(%arg2 : memref<100x10xf32>) {
|
||||
|
@ -132,7 +136,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
|||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
|
||||
#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0],
|
||||
iterator_types = ["parallel", "parallel", "parallel",
|
||||
"parallel"]}
|
||||
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
|
||||
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
|
||||
%temp_result = alloc() : memref<6x6x6x6xf32>
|
||||
|
@ -190,7 +196,8 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
|
|||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
||||
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0],
|
||||
iterator_types = ["parallel", "parallel"]}
|
||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
|
||||
%temp_result = alloc() : memref<6x6xf32>
|
||||
|
@ -244,3 +251,51 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
|||
// PLOOP: addf
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: mulf
|
||||
|
||||
// -----
|
||||
|
||||
func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
||||
-> memref<*xf32> {
|
||||
%c1 = constant 1 : index
|
||||
%c0 = constant 0 : index
|
||||
%1 = alloc(%arg2) : memref<?xf32>
|
||||
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
||||
affine_map<(d0) -> (d0)>],
|
||||
iterator_types = ["parallel"]}
|
||||
ins(%arg0 : memref<?xf32>) outs(%1 : memref<?xf32>) {
|
||||
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
||||
%13 = absf %arg3 : f32
|
||||
linalg.yield %13 : f32
|
||||
}
|
||||
%2 = lmhlo.reshape_memref_cast %1(%arg1)
|
||||
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
||||
return %2 : memref<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @view_result
|
||||
// CHECK: %[[C1:.*]] = constant 1
|
||||
// CHECK-NOT: linalg.generic
|
||||
// CHECK: scf.for {{.*}} step %[[C1]]
|
||||
// CHECK-NOT: scf.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: absf
|
||||
// CHECK: reshape_memref_cast
|
||||
|
||||
// TILED-LABEL: func @view_result
|
||||
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||
// TILED-NOT: linalg.generic
|
||||
// TILED: scf.for {{.*}} step %[[C2]]
|
||||
// TILED-NOT: scf.for
|
||||
// TILED: linalg.generic
|
||||
// TILED: absf
|
||||
// TILED: reshape_memref_cast
|
||||
|
||||
|
||||
// PLOOP-LABEL: func @view_result
|
||||
// PLOOP-NOT: linalg.generic
|
||||
// PLOOP: scf.parallel
|
||||
// PLOOP-NOT: scf.parallel
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: absf
|
||||
// PLOOP: reshape_memref_cast
|
||||
|
||||
|
|
Loading…
Reference in New Issue