Consider aliases through control flow for lhlo-fuse-linalg.
This should handle scf.if and shape.assuming regions, which we care about in kernel_gen. PiperOrigin-RevId: 342584762
This commit is contained in:
parent
1afada97ed
commit
be9ae88eaa
|
@ -85,13 +85,41 @@ class LhloFuseLinalgPass
|
||||||
if (!definingOp) {
|
if (!definingOp) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
|
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
|
||||||
auto alias = viewLike.getViewSource();
|
auto alias = viewLike.getViewSource();
|
||||||
if (result_buffers.insert(alias).second) {
|
if (result_buffers.insert(alias).second) {
|
||||||
worklist.push_back(alias);
|
worklist.push_back(alias);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto regionInterface =
|
||||||
|
dyn_cast<RegionBranchOpInterface>(definingOp)) {
|
||||||
|
for (Region& region : regionInterface.getOperation()->getRegions()) {
|
||||||
|
// Only consider regions that can return to the parent region.
|
||||||
|
SmallVector<RegionSuccessor, 2> successorRegions;
|
||||||
|
regionInterface.getSuccessorRegions(region.getRegionNumber(),
|
||||||
|
successorRegions);
|
||||||
|
if (llvm::none_of(successorRegions, [&](auto successorRegion) {
|
||||||
|
return successorRegion.isParent();
|
||||||
|
}))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
// Iterate over all immediate terminators and record the values
|
||||||
|
// corresponding to result_buffers of interest.
|
||||||
|
for (Block& block : region) {
|
||||||
|
if (block.empty()) continue;
|
||||||
|
Operation& operation = block.back();
|
||||||
|
if (!operation.hasTrait<OpTrait::ReturnLike>()) continue;
|
||||||
|
auto idx = result.dyn_cast<OpResult>().getResultNumber();
|
||||||
|
if (result_buffers.insert(operation.getOperand(idx)).second) {
|
||||||
|
worklist.push_back(operation.getOperand(idx));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
MLIRContext* ctx = func.getContext();
|
MLIRContext* ctx = func.getContext();
|
||||||
OpBuilder b(func);
|
OpBuilder b(func);
|
||||||
func.walk([&](linalg::GenericOp generic_op) {
|
func.walk([&](linalg::GenericOp generic_op) {
|
||||||
|
|
|
@ -299,3 +299,76 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
||||||
// PLOOP: absf
|
// PLOOP: absf
|
||||||
// PLOOP: memref_reshape
|
// PLOOP: memref_reshape
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Confirm that tiling information is passed through RegionBranchOpInterfaces.
|
||||||
|
// This test also uses memref_reshape, just to have a value to return through
|
||||||
|
// the if statement.
|
||||||
|
func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
||||||
|
-> memref<*xf32> {
|
||||||
|
%c1 = constant 1 : index
|
||||||
|
%c0 = constant 0 : index
|
||||||
|
%1 = alloc(%arg2) : memref<?xf32>
|
||||||
|
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
||||||
|
affine_map<(d0) -> (d0)>],
|
||||||
|
iterator_types = ["parallel"]}
|
||||||
|
ins(%arg0 : memref<?xf32>) outs(%1 : memref<?xf32>) {
|
||||||
|
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
||||||
|
%13 = absf %arg3 : f32
|
||||||
|
linalg.yield %13 : f32
|
||||||
|
}
|
||||||
|
%true = constant 1 : i1
|
||||||
|
%3 = scf.if %true -> memref<*xf32> {
|
||||||
|
%2 = memref_reshape %1(%arg1)
|
||||||
|
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
||||||
|
scf.yield %2 : memref<*xf32>
|
||||||
|
} else {
|
||||||
|
%2 = memref_reshape %1(%arg1)
|
||||||
|
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
||||||
|
scf.yield %2 : memref<*xf32>
|
||||||
|
}
|
||||||
|
return %3 : memref<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @branching_result
|
||||||
|
// CHECK: %[[C1:.*]] = constant 1
|
||||||
|
// CHECK-NOT: linalg.generic
|
||||||
|
// CHECK: scf.for {{.*}} step %[[C1]]
|
||||||
|
// CHECK-NOT: scf.for
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK: absf
|
||||||
|
// CHECK: scf.if
|
||||||
|
// CHECK: memref_reshape
|
||||||
|
// CHECK: scf.yield
|
||||||
|
// CHECK: else
|
||||||
|
// CHECK: memref_reshape
|
||||||
|
// CHECK: scf.yield
|
||||||
|
|
||||||
|
// TILED-LABEL: func @branching_result
|
||||||
|
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||||
|
// TILED-NOT: linalg.generic
|
||||||
|
// TILED: scf.for {{.*}} step %[[C2]]
|
||||||
|
// TILED-NOT: scf.for
|
||||||
|
// TILED: linalg.generic
|
||||||
|
// TILED: absf
|
||||||
|
// TILED: scf.if
|
||||||
|
// TILED: memref_reshape
|
||||||
|
// TILED: scf.yield
|
||||||
|
// TILED: else
|
||||||
|
// TILED: memref_reshape
|
||||||
|
// TILED: scf.yield
|
||||||
|
|
||||||
|
// PLOOP-LABEL: func @branching_result
|
||||||
|
// PLOOP-NOT: linalg.generic
|
||||||
|
// PLOOP: scf.parallel
|
||||||
|
// PLOOP-NOT: scf.parallel
|
||||||
|
// PLOOP: linalg.generic
|
||||||
|
// PLOOP: absf
|
||||||
|
// PLOOP: scf.if
|
||||||
|
// PLOOP: memref_reshape
|
||||||
|
// PLOOP: scf.yield
|
||||||
|
// PLOOP: else
|
||||||
|
// PLOOP: memref_reshape
|
||||||
|
// PLOOP: scf.yield
|
||||||
|
|
Loading…
Reference in New Issue