[MLIR][KernelGen] Better rank specialization for clusters containing `mhlo.select`
Merge `mhlo.select` into rank specialization clusters. Infer shape equalities correctly from `mhlo.select` (and also from `mhlo.clamp`). This allows to lower the relu kernel completely flattened. PiperOrigin-RevId: 379925793
This commit is contained in:
parent
376da8592f
commit
9f47ff607b
|
@ -68,8 +68,7 @@ bool IsClusterable(Operation *op) {
|
|||
if (op->getNumOperands() == 0) return false;
|
||||
return (op->hasTrait<mlir::OpTrait::Elementwise>() &&
|
||||
op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) ||
|
||||
(op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>() &&
|
||||
op->hasTrait<chlo::OpTrait::Broadcasting>());
|
||||
op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>();
|
||||
}
|
||||
|
||||
struct RankSpecializationClusterPattern : public RewritePattern {
|
||||
|
@ -714,6 +713,8 @@ Value MaterializeDefaultRankSpecialization(OpBuilder &b, Location loc,
|
|||
}
|
||||
|
||||
// This is a very limited form of shape inference. It is correct but incomplete.
|
||||
// TODO(frgossen): Infer shape equalities from surrounding shape constraints
|
||||
// when these are generated.
|
||||
SmallVector<SmallVector<Value, 4>, 4> FindNonScalarShapeEquivalences(
|
||||
chlo::RankSpecializationClusterOp op) {
|
||||
llvm::EquivalenceClasses<Value> eqs;
|
||||
|
@ -735,12 +736,17 @@ SmallVector<SmallVector<Value, 4>, 4> FindNonScalarShapeEquivalences(
|
|||
if (!nested_op.getOperands().empty() && !nested_op.getResults().empty())
|
||||
eqs.unionSets(nested_op.getResult(0), nested_op.getOperand(0));
|
||||
}
|
||||
// TODO(frgossen): Replace this with a check for the appropriate trait when
|
||||
// that is available.
|
||||
}
|
||||
|
||||
// Find equalities through special knowledge of ops.
|
||||
for (Operation &nested_op : op.getBody()->without_terminator()) {
|
||||
if (auto select_op = llvm::dyn_cast<mhlo::SelectOp>(nested_op)) {
|
||||
union_sets(
|
||||
{select_op.on_true(), select_op.on_false(), select_op.getResult()});
|
||||
}
|
||||
if (auto clamp_op = llvm::dyn_cast<mhlo::ClampOp>(nested_op)) {
|
||||
union_sets({clamp_op.operand(), clamp_op.getResult()});
|
||||
}
|
||||
}
|
||||
|
||||
// Convert to a list-like equivalence class representation.
|
||||
|
|
|
@ -615,3 +615,40 @@ func @all_equal_shapes_inferrable(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
|
|||
// CHECK-SCF-DAG: %[[RES_S:.*]] = shape.broadcast %8, %9
|
||||
// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[RES_S]])
|
||||
// CHECK-SCF: return %[[RES]]
|
||||
|
||||
// -----
|
||||
|
||||
// All shapes are equal, which is inferrable through the select op.
|
||||
// CHECK-LABEL: @relu_grad
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
|
||||
func @relu_grad(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG1]], %[[ARG0]])
|
||||
// CHECK: ^bb0(%[[ARG1_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>)
|
||||
// CHECK: %[[TMP0:.*]] = "chlo.constant_like"(%[[ARG0_]]) {value = 0.0{{.*}}e+00 : f32}
|
||||
// CHECK: %[[TMP1:.*]] = "mhlo.compare"(%[[ARG0_]], %[[TMP0]]) {comparison_direction = "GT"}
|
||||
// CHECK: %[[TMP2:.*]] = "mhlo.select"(%[[TMP1]], %[[ARG1_]], %[[TMP0]])
|
||||
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]])
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%1 = "mhlo.compare"(%arg0, %0) {comparison_direction = "GT"} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1>
|
||||
%2 = "mhlo.select"(%1, %arg1, %0) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %2 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-SCF-LABEL: @relu_grad
|
||||
// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
|
||||
// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]]
|
||||
// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]]
|
||||
// CHECK-SCF-DAG: %[[S:.*]] = shape.any %[[S1]], %[[S0]]
|
||||
// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S]]
|
||||
// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
|
||||
// CHECK-SCF-DAG: %[[FLAT0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
|
||||
// CHECK-SCF-DAG: %[[FLAT1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
|
||||
// CHECK-SCF-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%[[FLAT0]]) {value = 0.0{{.*}}+00 : f32}
|
||||
// CHECK-SCF-DAG: %[[PRED:.*]] = "mhlo.compare"(%[[FLAT0]], %[[ZERO]]) {comparison_direction = "GT"}
|
||||
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = "mhlo.select"(%[[PRED]], %[[FLAT1]], %[[ZERO]])
|
||||
// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]]
|
||||
// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]]
|
||||
// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[S1]], %[[S0]]
|
||||
// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]])
|
||||
// CHECK-SCF: return %[[RES]]
|
||||
|
|
Loading…
Reference in New Issue