Extend lhlo-fuse-linalg slightly to support aliased returns.

The fusion heuristic identifies the root of a fusion by checking whether an
output of a linalg operation is a function result. Previously it did not
consider outputs that reach a function result only through aliasing operations
(such as casts); this change follows such view-like aliases when collecting
result buffers.

PiperOrigin-RevId: 337479910
Stephan Herhut 2020-10-16 04:03:33 -07:00 committed by TensorFlow MLIR Team
parent 6a91ac6d8f
commit 2e30b59ddc
2 changed files with 79 additions and 5 deletions
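For illustration, the case the old check missed looks roughly like the following (a condensed, hypothetical variant of the new test added at the bottom of this change; the function and value names are made up). The linalg.generic writes into a locally allocated buffer, and that buffer only reaches the function result through a view-like cast, so a check that only looks at returned values never treats the buffer as a fusion root:

  func @aliased_result(%input: memref<?xf32>, %shape: memref<?xindex>, %size: index)
      -> memref<*xf32> {
    // The generic writes into %buf rather than directly into a returned value.
    %buf = alloc(%size) : memref<?xf32>
    linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
                                     affine_map<(d0) -> (d0)>],
                    iterator_types = ["parallel"]}
        ins(%input : memref<?xf32>) outs(%buf : memref<?xf32>) {
    ^bb0(%in: f32, %out: f32):
      %abs = absf %in : f32
      linalg.yield %abs : f32
    }
    // %buf only escapes through this view-like cast of the result.
    %ret = lmhlo.reshape_memref_cast %buf(%shape)
        : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
    return %ret : memref<*xf32>
  }

With the worklist added below, %buf ends up in result_buffers because it is the view source of %ret, which is a function result, so the generic is again recognized as a fusion root.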


@@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/FoldUtils.h"
@@ -73,6 +74,24 @@ class LhloFuseLinalgPass
        result_buffers.insert(operand);
      }
    }
    // Resolve aliasing operations (like casts) on the result to identify
    // results. This only handles escaping results.
    // TODO(herhut): Use BufferizeAliasAnalysis for this.
    llvm::SmallVector<Value, 4> worklist(result_buffers.begin(),
                                         result_buffers.end());
    while (!worklist.empty()) {
      Value result = worklist.pop_back_val();
      auto definingOp = result.getDefiningOp();
      if (!definingOp) {
        continue;
      }
      if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
        auto alias = viewLike.getViewSource();
        if (result_buffers.insert(alias).second) {
          worklist.push_back(alias);
        }
      }
    }
    MLIRContext* ctx = func.getContext();
    OpBuilder b(func);
    OperationFolder folder(ctx);


@@ -3,7 +3,8 @@
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0],
                       iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
             %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
  %temp_result = alloc() : memref<6x6xf32>
@@ -73,7 +74,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
  }
  %1 = alloc() : memref<100x10xf32>
  linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>)
      outs(%1 : memref<100x10xf32>) {
@@ -83,7 +86,8 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
  }
  dealloc %0 : memref<100x10xf32>
  linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%1 : memref<100x10xf32>)
      outs(%arg2 : memref<100x10xf32>) {
@@ -132,7 +136,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
// -----
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0],
                       iterator_types = ["parallel", "parallel", "parallel",
                                         "parallel"]}
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
                %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
  %temp_result = alloc() : memref<6x6x6x6xf32>
@@ -190,7 +196,8 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
// -----
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0],
                       iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
             %summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
  %temp_result = alloc() : memref<6x6xf32>
@@ -244,3 +251,51 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
// PLOOP: addf
// PLOOP: linalg.generic
// PLOOP: mulf
// -----
func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
-> memref<*xf32> {
%c1 = constant 1 : index
%c0 = constant 0 : index
%1 = alloc(%arg2) : memref<?xf32>
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%arg0 : memref<?xf32>) outs(%1 : memref<?xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%13 = absf %arg3 : f32
linalg.yield %13 : f32
}
%2 = lmhlo.reshape_memref_cast %1(%arg1)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
return %2 : memref<*xf32>
}
// CHECK-LABEL: func @view_result
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: reshape_memref_cast
// TILED-LABEL: func @view_result
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: reshape_memref_cast
// PLOOP-LABEL: func @view_result
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: reshape_memref_cast