Consider aliases through control flow for lhlo-fuse-linalg.

This should handle scf.if and shape.assuming regions,
which we care about in kernel_gen.

PiperOrigin-RevId: 342584762
This commit is contained in:
Tres Popp 2020-11-16 01:11:53 -08:00 committed by TensorFlow MLIR Team
parent 1afada97ed
commit be9ae88eaa
2 changed files with 101 additions and 0 deletions

View File

@ -85,13 +85,41 @@ class LhloFuseLinalgPass
if (!definingOp) { if (!definingOp) {
continue; continue;
} }
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) { if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
auto alias = viewLike.getViewSource(); auto alias = viewLike.getViewSource();
if (result_buffers.insert(alias).second) { if (result_buffers.insert(alias).second) {
worklist.push_back(alias); worklist.push_back(alias);
} }
} }
if (auto regionInterface =
dyn_cast<RegionBranchOpInterface>(definingOp)) {
for (Region& region : regionInterface.getOperation()->getRegions()) {
// Only consider regions that can return to the parent region.
SmallVector<RegionSuccessor, 2> successorRegions;
regionInterface.getSuccessorRegions(region.getRegionNumber(),
successorRegions);
if (llvm::none_of(successorRegions, [&](auto successorRegion) {
return successorRegion.isParent();
}))
continue;
// Iterate over all immediate terminators and record the values
// corresponding to result_buffers of interest.
for (Block& block : region) {
if (block.empty()) continue;
Operation& operation = block.back();
if (!operation.hasTrait<OpTrait::ReturnLike>()) continue;
auto idx = result.dyn_cast<OpResult>().getResultNumber();
if (result_buffers.insert(operation.getOperand(idx)).second) {
worklist.push_back(operation.getOperand(idx));
} }
}
}
}
}
MLIRContext* ctx = func.getContext(); MLIRContext* ctx = func.getContext();
OpBuilder b(func); OpBuilder b(func);
func.walk([&](linalg::GenericOp generic_op) { func.walk([&](linalg::GenericOp generic_op) {

View File

@ -299,3 +299,76 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
// PLOOP: absf // PLOOP: absf
// PLOOP: memref_reshape // PLOOP: memref_reshape
// -----
// Confirm that tiling information is passed through RegionBranchOpInterfaces.
// This test also uses memref_reshape, just to have a value to return through
// the if statement.
func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
-> memref<*xf32> {
%c1 = constant 1 : index
%c0 = constant 0 : index
%1 = alloc(%arg2) : memref<?xf32>
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
ins(%arg0 : memref<?xf32>) outs(%1 : memref<?xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%13 = absf %arg3 : f32
linalg.yield %13 : f32
}
%true = constant 1 : i1
%3 = scf.if %true -> memref<*xf32> {
%2 = memref_reshape %1(%arg1)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
scf.yield %2 : memref<*xf32>
} else {
%2 = memref_reshape %1(%arg1)
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
scf.yield %2 : memref<*xf32>
}
return %3 : memref<*xf32>
}
// CHECK-LABEL: func @branching_result
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: absf
// CHECK: scf.if
// CHECK: memref_reshape
// CHECK: scf.yield
// CHECK: else
// CHECK: memref_reshape
// CHECK: scf.yield
// TILED-LABEL: func @branching_result
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: absf
// TILED: scf.if
// TILED: memref_reshape
// TILED: scf.yield
// TILED: else
// TILED: memref_reshape
// TILED: scf.yield
// PLOOP-LABEL: func @branching_result
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: absf
// PLOOP: scf.if
// PLOOP: memref_reshape
// PLOOP: scf.yield
// PLOOP: else
// PLOOP: memref_reshape
// PLOOP: scf.yield