[MLIR][KernelGen] Better rank specialization for clusters containing `mhlo.select`

Merge `mhlo.select` into rank specialization clusters. Infer shape equalities
correctly from `mhlo.select` (and also from `mhlo.clamp`). This allows to lower
the relu kernel completely flattened.

PiperOrigin-RevId: 379925793
This commit is contained in:
A. Unique TensorFlower 2021-06-17 04:05:09 -07:00 committed by TensorFlow MLIR Team
parent 376da8592f
commit 9f47ff607b
2 changed files with 47 additions and 4 deletions

View File

@ -68,8 +68,7 @@ bool IsClusterable(Operation *op) {
if (op->getNumOperands() == 0) return false;
return (op->hasTrait<mlir::OpTrait::Elementwise>() &&
op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) ||
(op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>() &&
op->hasTrait<chlo::OpTrait::Broadcasting>());
op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>();
}
struct RankSpecializationClusterPattern : public RewritePattern {
@ -714,6 +713,8 @@ Value MaterializeDefaultRankSpecialization(OpBuilder &b, Location loc,
}
// This is a very limited form of shape inference. It is correct but incomplete.
// TODO(frgossen): Infer shape equalities from surrounding shape constraints
// when these are generated.
SmallVector<SmallVector<Value, 4>, 4> FindNonScalarShapeEquivalences(
chlo::RankSpecializationClusterOp op) {
llvm::EquivalenceClasses<Value> eqs;
@ -735,12 +736,17 @@ SmallVector<SmallVector<Value, 4>, 4> FindNonScalarShapeEquivalences(
if (!nested_op.getOperands().empty() && !nested_op.getResults().empty())
eqs.unionSets(nested_op.getResult(0), nested_op.getOperand(0));
}
// TODO(frgossen): Replace this with a check for the appropriate trait when
// that is available.
}
// Find equalities through special knowledge of ops.
for (Operation &nested_op : op.getBody()->without_terminator()) {
if (auto select_op = llvm::dyn_cast<mhlo::SelectOp>(nested_op)) {
union_sets(
{select_op.on_true(), select_op.on_false(), select_op.getResult()});
}
if (auto clamp_op = llvm::dyn_cast<mhlo::ClampOp>(nested_op)) {
union_sets({clamp_op.operand(), clamp_op.getResult()});
}
}
// Convert to a list-like equivalence class representation.

View File

@ -615,3 +615,40 @@ func @all_equal_shapes_inferrable(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>)
// CHECK-SCF-DAG: %[[RES_S:.*]] = shape.broadcast %8, %9
// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[RES_S]])
// CHECK-SCF: return %[[RES]]
// -----
// All shapes are equal, which is inferrable through the select op.
// CHECK-LABEL: @relu_grad
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
func @relu_grad(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG1]], %[[ARG0]])
// CHECK: ^bb0(%[[ARG1_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>)
// CHECK: %[[TMP0:.*]] = "chlo.constant_like"(%[[ARG0_]]) {value = 0.0{{.*}}e+00 : f32}
// CHECK: %[[TMP1:.*]] = "mhlo.compare"(%[[ARG0_]], %[[TMP0]]) {comparison_direction = "GT"}
// CHECK: %[[TMP2:.*]] = "mhlo.select"(%[[TMP1]], %[[ARG1_]], %[[TMP0]])
// CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]])
// CHECK: return %[[RES]]
%0 = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32>
%1 = "mhlo.compare"(%arg0, %0) {comparison_direction = "GT"} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1>
%2 = "mhlo.select"(%1, %arg1, %0) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %2 : tensor<*xf32>
}
// CHECK-SCF-LABEL: @relu_grad
// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>)
// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]]
// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]]
// CHECK-SCF-DAG: %[[S:.*]] = shape.any %[[S1]], %[[S0]]
// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S]]
// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]]
// CHECK-SCF-DAG: %[[FLAT0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]])
// CHECK-SCF-DAG: %[[FLAT1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]])
// CHECK-SCF-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%[[FLAT0]]) {value = 0.0{{.*}}+00 : f32}
// CHECK-SCF-DAG: %[[PRED:.*]] = "mhlo.compare"(%[[FLAT0]], %[[ZERO]]) {comparison_direction = "GT"}
// CHECK-SCF: %[[UNSHAPED_RES:.*]] = "mhlo.select"(%[[PRED]], %[[FLAT1]], %[[ZERO]])
// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]]
// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]]
// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[S1]], %[[S0]]
// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]])
// CHECK-SCF: return %[[RES]]