Consider aliases through control flow for lhlo-fuse-linalg.
This should handle scf.if and shape.assuming regions, which we care about in kernel_gen. PiperOrigin-RevId: 342584762
This commit is contained in:
parent
1afada97ed
commit
be9ae88eaa
|
@ -85,13 +85,41 @@ class LhloFuseLinalgPass
|
|||
if (!definingOp) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto viewLike = dyn_cast<ViewLikeOpInterface>(definingOp)) {
|
||||
auto alias = viewLike.getViewSource();
|
||||
if (result_buffers.insert(alias).second) {
|
||||
worklist.push_back(alias);
|
||||
}
|
||||
}
|
||||
|
||||
if (auto regionInterface =
|
||||
dyn_cast<RegionBranchOpInterface>(definingOp)) {
|
||||
for (Region& region : regionInterface.getOperation()->getRegions()) {
|
||||
// Only consider regions that can return to the parent region.
|
||||
SmallVector<RegionSuccessor, 2> successorRegions;
|
||||
regionInterface.getSuccessorRegions(region.getRegionNumber(),
|
||||
successorRegions);
|
||||
if (llvm::none_of(successorRegions, [&](auto successorRegion) {
|
||||
return successorRegion.isParent();
|
||||
}))
|
||||
continue;
|
||||
|
||||
// Iterate over all immediate terminators and record the values
|
||||
// corresponding to result_buffers of interest.
|
||||
for (Block& block : region) {
|
||||
if (block.empty()) continue;
|
||||
Operation& operation = block.back();
|
||||
if (!operation.hasTrait<OpTrait::ReturnLike>()) continue;
|
||||
auto idx = result.dyn_cast<OpResult>().getResultNumber();
|
||||
if (result_buffers.insert(operation.getOperand(idx)).second) {
|
||||
worklist.push_back(operation.getOperand(idx));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MLIRContext* ctx = func.getContext();
|
||||
OpBuilder b(func);
|
||||
func.walk([&](linalg::GenericOp generic_op) {
|
||||
|
|
|
@ -299,3 +299,76 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
|||
// PLOOP: absf
|
||||
// PLOOP: memref_reshape
|
||||
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Confirm that tiling information is passed through RegionBranchOpInterfaces.
|
||||
// This test also uses memref_reshape, just to have a value to return through
|
||||
// the if statement.
|
||||
func @branching_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index)
|
||||
-> memref<*xf32> {
|
||||
%c1 = constant 1 : index
|
||||
%c0 = constant 0 : index
|
||||
%1 = alloc(%arg2) : memref<?xf32>
|
||||
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>,
|
||||
affine_map<(d0) -> (d0)>],
|
||||
iterator_types = ["parallel"]}
|
||||
ins(%arg0 : memref<?xf32>) outs(%1 : memref<?xf32>) {
|
||||
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
||||
%13 = absf %arg3 : f32
|
||||
linalg.yield %13 : f32
|
||||
}
|
||||
%true = constant 1 : i1
|
||||
%3 = scf.if %true -> memref<*xf32> {
|
||||
%2 = memref_reshape %1(%arg1)
|
||||
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
||||
scf.yield %2 : memref<*xf32>
|
||||
} else {
|
||||
%2 = memref_reshape %1(%arg1)
|
||||
: (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
|
||||
scf.yield %2 : memref<*xf32>
|
||||
}
|
||||
return %3 : memref<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @branching_result
|
||||
// CHECK: %[[C1:.*]] = constant 1
|
||||
// CHECK-NOT: linalg.generic
|
||||
// CHECK: scf.for {{.*}} step %[[C1]]
|
||||
// CHECK-NOT: scf.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: absf
|
||||
// CHECK: scf.if
|
||||
// CHECK: memref_reshape
|
||||
// CHECK: scf.yield
|
||||
// CHECK: else
|
||||
// CHECK: memref_reshape
|
||||
// CHECK: scf.yield
|
||||
|
||||
// TILED-LABEL: func @branching_result
|
||||
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||
// TILED-NOT: linalg.generic
|
||||
// TILED: scf.for {{.*}} step %[[C2]]
|
||||
// TILED-NOT: scf.for
|
||||
// TILED: linalg.generic
|
||||
// TILED: absf
|
||||
// TILED: scf.if
|
||||
// TILED: memref_reshape
|
||||
// TILED: scf.yield
|
||||
// TILED: else
|
||||
// TILED: memref_reshape
|
||||
// TILED: scf.yield
|
||||
|
||||
// PLOOP-LABEL: func @branching_result
|
||||
// PLOOP-NOT: linalg.generic
|
||||
// PLOOP: scf.parallel
|
||||
// PLOOP-NOT: scf.parallel
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: absf
|
||||
// PLOOP: scf.if
|
||||
// PLOOP: memref_reshape
|
||||
// PLOOP: scf.yield
|
||||
// PLOOP: else
|
||||
// PLOOP: memref_reshape
|
||||
// PLOOP: scf.yield
|
||||
|
|
Loading…
Reference in New Issue