diff --git a/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc b/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc index 6dc5b64..8f50ad0 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/FoldUtils.h" @@ -73,6 +74,24 @@ class LhloFuseLinalgPass result_buffers.insert(operand); } } + // Resolve aliasing operations (like casts) on the result to identify + // results. This only handles escaping results. + // TODO(herhut): Use BufferizeAliasAnalysis for this. + llvm::SmallVector worklist(result_buffers.begin(), + result_buffers.end()); + while (!worklist.empty()) { + Value result = worklist.pop_back_val(); + auto definingOp = result.getDefiningOp(); + if (!definingOp) { + continue; + } + if (auto viewLike = dyn_cast(definingOp)) { + auto alias = viewLike.getViewSource(); + if (result_buffers.insert(alias).second) { + worklist.push_back(alias); + } + } + } MLIRContext* ctx = func.getContext(); OpBuilder b(func); OperationFolder folder(ctx); diff --git a/tests/lhlo-fuse-linalg.mlir b/tests/lhlo-fuse-linalg.mlir index 9a218b3..e51bdfe 100644 --- a/tests/lhlo-fuse-linalg.mlir +++ b/tests/lhlo-fuse-linalg.mlir @@ -3,7 +3,8 @@ // RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP #map0 = affine_map<(d0, d1) -> (d0, d1)> -#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} +#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel"]} func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, %summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) { %temp_result = alloc() : memref<6x6xf32> @@ -73,7 +74,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, } %1 = alloc() : memref<100x10xf32> linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg0, %0 : memref<100x10xf32>, memref<100x10xf32>) outs(%1 : memref<100x10xf32>) { @@ -83,7 +86,8 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, } dealloc %0 : memref<100x10xf32> linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%1 : memref<100x10xf32>) outs(%arg2 : memref<100x10xf32>) { @@ -132,7 +136,9 @@ func @fusion_of_three(%arg0: memref<100x10xf32>, // ----- #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} +#pointwise_4d_trait = {indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel", "parallel", + "parallel"]} func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>, %summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) { %temp_result = alloc() : memref<6x6x6x6xf32> @@ -190,7 +196,8 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32 // ----- #map0 = affine_map<(d0, d1) -> (d0, d1)> -#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} +#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], + iterator_types = ["parallel", "parallel"]} func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, %summand_2: memref<6x6xf32>) -> memref<6x6xf32> { %temp_result = alloc() : memref<6x6xf32> @@ -244,3 +251,51 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>, // PLOOP: addf // PLOOP: linalg.generic // PLOOP: mulf + +// ----- + +func @view_result(%arg0: memref, %arg1: memref, %arg2: index) + -> memref<*xf32> { + %c1 = constant 1 : index + %c0 = constant 0 : index + %1 = alloc(%arg2) : memref + linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%arg0 : memref) outs(%1 : memref) { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %13 = absf %arg3 : f32 + linalg.yield %13 : f32 + } + %2 = lmhlo.reshape_memref_cast %1(%arg1) + : (memref, memref) -> memref<*xf32> + return %2 : memref<*xf32> +} + +// CHECK-LABEL: func @view_result +// CHECK: %[[C1:.*]] = constant 1 +// CHECK-NOT: linalg.generic +// CHECK: scf.for {{.*}} step %[[C1]] +// CHECK-NOT: scf.for +// CHECK: linalg.generic +// CHECK: absf +// CHECK: reshape_memref_cast + +// TILED-LABEL: func @view_result +// TILED-DAG: %[[C2:.*]] = constant 2 +// TILED-NOT: linalg.generic +// TILED: scf.for {{.*}} step %[[C2]] +// TILED-NOT: scf.for +// TILED: linalg.generic +// TILED: absf +// TILED: reshape_memref_cast + + +// PLOOP-LABEL: func @view_result +// PLOOP-NOT: linalg.generic +// PLOOP: scf.parallel +// PLOOP-NOT: scf.parallel +// PLOOP: linalg.generic +// PLOOP: absf +// PLOOP: reshape_memref_cast +