Extend lhlo-fuse-linalg slightly to support aliased returns.
The fusion heuristic identifies the root of a fusion by checking whether an output of a linalg operation is a function result. Previously, it did not consider outputs that flow to the function result through aliasing operations (like casts).
PiperOrigin-RevId: 337479910
This commit is contained in:
parent
6a91ac6d8f
commit
2e30b59ddc
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||||
#include "mlir/Dialect/SCF/SCF.h"
|
#include "mlir/Dialect/SCF/SCF.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||||
|
#include "mlir/Interfaces/ViewLikeInterface.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Transforms/FoldUtils.h"
|
#include "mlir/Transforms/FoldUtils.h"
|
||||||
|
|
||||||
|
@ -73,6 +74,24 @@ class LhloFuseLinalgPass
|
||||||
result_buffers.insert(operand);
|
result_buffers.insert(operand);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Resolve aliasing operations (like casts) on the result to identify
|
||||||
|
// results. This only handles escaping results.
|
||||||
|
// TODO(herhut): Use BufferizeAliasAnalysis for this.
|
||||||
|
llvm::SmallVector<Value, 4> worklist(result_buffers.begin(),
|
||||||
|
result_buffers.end());
|
||||||
|
while (!worklist.empty()) {
|
||||||
|
Value result = worklist.pop_back_val();
|
||||||
|
auto definingOp = result.getDefiningOp();
|
||||||
|
if (!definingOp) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
|
||||||
|
auto alias = viewLike.getViewSource();
|
||||||
|
if (result_buffers.insert(alias).second) {
|
||||||
|
worklist.push_back(alias);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
MLIRContext* ctx = func.getContext();
|
MLIRContext* ctx = func.getContext();
|
||||||
OpBuilder b(func);
|
OpBuilder b(func);
|
||||||
OperationFolder folder(ctx);
|
OperationFolder folder(ctx);
|
||||||
|
|
|
@ -3,7 +3,8 @@
|
||||||
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
|
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
|
||||||
|
|
||||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0],
|
||||||
|
iterator_types = ["parallel", "parallel"]}
|
||||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||||
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
|
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
|
||||||
%temp_result = alloc() : memref<6x6xf32>
|
%temp_result = alloc() : memref<6x6xf32>
|
||||||
|
@ -73,7 +74,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||||
}
|
}
|
||||||
%1 = alloc() : memref<100x10xf32>
|
%1 = alloc() : memref<100x10xf32>
|
||||||
linalg.generic {
|
linalg.generic {
|
||||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
|
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||||
|
affine_map<(d0, d1) -> (d0, d1)>,
|
||||||
|
affine_map<(d0, d1) -> (d0, d1)>],
|
||||||
iterator_types = ["parallel", "parallel"]}
|
iterator_types = ["parallel", "parallel"]}
|
||||||
ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>)
|
ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>)
|
||||||
outs(%1 : memref<100x10xf32>) {
|
outs(%1 : memref<100x10xf32>) {
|
||||||
|
@ -83,7 +86,8 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||||
}
|
}
|
||||||
dealloc %0 : memref<100x10xf32>
|
dealloc %0 : memref<100x10xf32>
|
||||||
linalg.generic {
|
linalg.generic {
|
||||||
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
|
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
|
||||||
|
affine_map<(d0, d1) -> (d0, d1)>],
|
||||||
iterator_types = ["parallel", "parallel"]}
|
iterator_types = ["parallel", "parallel"]}
|
||||||
ins(%1 : memref<100x10xf32>)
|
ins(%1 : memref<100x10xf32>)
|
||||||
outs(%arg2 : memref<100x10xf32>) {
|
outs(%arg2 : memref<100x10xf32>) {
|
||||||
|
@ -132,7 +136,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||||
#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
|
#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0],
|
||||||
|
iterator_types = ["parallel", "parallel", "parallel",
|
||||||
|
"parallel"]}
|
||||||
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
|
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
|
||||||
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
|
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
|
||||||
%temp_result = alloc() : memref<6x6x6x6xf32>
|
%temp_result = alloc() : memref<6x6x6x6xf32>
|
||||||
|
@ -190,7 +196,8 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0],
|
||||||
|
iterator_types = ["parallel", "parallel"]}
|
||||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||||
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
|
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
|
||||||
%temp_result = alloc() : memref<6x6xf32>
|
%temp_result = alloc() : memref<6x6xf32>
|
||||||
|
@ -244,3 +251,51 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||||
// PLOOP: addf
|
// PLOOP: addf
|
||||||
// PLOOP: linalg.generic
|
// PLOOP: linalg.generic
|
||||||
// PLOOP: mulf
|
// PLOOP: mulf
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
||||||
|
-> memref<*xf32> {
|
||||||
|
%c1 = constant 1 : index
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%1 = alloc(%arg2) : memref<?xf32>
|
||||||
|
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
||||||
|
affine_map<(d0) -> (d0)>],
|
||||||
|
iterator_types = ["parallel"]}
|
||||||
|
ins(%arg0 : memref<?xf32>) outs(%1 : memref<?xf32>) {
|
||||||
|
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
||||||
|
%13 = absf %arg3 : f32
|
||||||
|
linalg.yield %13 : f32
|
||||||
|
}
|
||||||
|
%2 = lmhlo.reshape_memref_cast %1(%arg1)
|
||||||
|
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
||||||
|
return %2 : memref<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @view_result
|
||||||
|
// CHECK: %[[C1:.*]] = constant 1
|
||||||
|
// CHECK-NOT: linalg.generic
|
||||||
|
// CHECK: scf.for {{.*}} step %[[C1]]
|
||||||
|
// CHECK-NOT: scf.for
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: absf
|
||||||
|
// CHECK: reshape_memref_cast
|
||||||
|
|
||||||
|
// TILED-LABEL: func @view_result
|
||||||
|
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||||
|
// TILED-NOT: linalg.generic
|
||||||
|
// TILED: scf.for {{.*}} step %[[C2]]
|
||||||
|
// TILED-NOT: scf.for
|
||||||
|
// TILED: linalg.generic
|
||||||
|
// TILED: absf
|
||||||
|
// TILED: reshape_memref_cast
|
||||||
|
|
||||||
|
|
||||||
|
// PLOOP-LABEL: func @view_result
|
||||||
|
// PLOOP-NOT: linalg.generic
|
||||||
|
// PLOOP: scf.parallel
|
||||||
|
// PLOOP-NOT: scf.parallel
|
||||||
|
// PLOOP: linalg.generic
|
||||||
|
// PLOOP: absf
|
||||||
|
// PLOOP: reshape_memref_cast
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue