diff --git a/lib/Dialect/mhlo/transforms/rank_specialization.cc b/lib/Dialect/mhlo/transforms/rank_specialization.cc index cd9eb00..1ef3f62 100644 --- a/lib/Dialect/mhlo/transforms/rank_specialization.cc +++ b/lib/Dialect/mhlo/transforms/rank_specialization.cc @@ -68,8 +68,7 @@ bool IsClusterable(Operation *op) { if (op->getNumOperands() == 0) return false; return (op->hasTrait() && op->hasTrait()) || - (op->hasTrait() && - op->hasTrait()); + op->hasTrait(); } struct RankSpecializationClusterPattern : public RewritePattern { @@ -714,6 +713,8 @@ Value MaterializeDefaultRankSpecialization(OpBuilder &b, Location loc, } // This is a very limited form of shape inference. It is correct but incomplete. +// TODO(frgossen): Infer shape equalities from surrounding shape constraints +// when these are generated. SmallVector, 4> FindNonScalarShapeEquivalences( chlo::RankSpecializationClusterOp op) { llvm::EquivalenceClasses eqs; @@ -735,12 +736,17 @@ SmallVector, 4> FindNonScalarShapeEquivalences( if (!nested_op.getOperands().empty() && !nested_op.getResults().empty()) eqs.unionSets(nested_op.getResult(0), nested_op.getOperand(0)); } - // TODO(frgossen): Replace this with a check for the appropriate trait when - // that is available. + } + + // Find equalities through special knowledge of ops. + for (Operation &nested_op : op.getBody()->without_terminator()) { if (auto select_op = llvm::dyn_cast(nested_op)) { union_sets( {select_op.on_true(), select_op.on_false(), select_op.getResult()}); } + if (auto clamp_op = llvm::dyn_cast(nested_op)) { + union_sets({clamp_op.operand(), clamp_op.getResult()}); + } } // Convert to a list-like equivalence class representation. diff --git a/tests/rank-specialization.mlir b/tests/rank-specialization.mlir index c71b71b..fba118e 100644 --- a/tests/rank-specialization.mlir +++ b/tests/rank-specialization.mlir @@ -615,3 +615,40 @@ func @all_equal_shapes_inferrable(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>) // CHECK-SCF-DAG: %[[RES_S:.*]] = shape.broadcast %8, %9 // CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[RES_S]]) // CHECK-SCF: return %[[RES]] + +// ----- + +// All shapes are equal, which is inferrable through the select op. +// CHECK-LABEL: @relu_grad +// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) +func @relu_grad(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG1]], %[[ARG0]]) + // CHECK: ^bb0(%[[ARG1_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>) + // CHECK: %[[TMP0:.*]] = "chlo.constant_like"(%[[ARG0_]]) {value = 0.0{{.*}}e+00 : f32} + // CHECK: %[[TMP1:.*]] = "mhlo.compare"(%[[ARG0_]], %[[TMP0]]) {comparison_direction = "GT"} + // CHECK: %[[TMP2:.*]] = "mhlo.select"(%[[TMP1]], %[[ARG1_]], %[[TMP0]]) + // CHECK: "chlo.rank_specialization_cluster_yield"(%[[TMP2]]) + // CHECK: return %[[RES]] + %0 = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> + %1 = "mhlo.compare"(%arg0, %0) {comparison_direction = "GT"} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1> + %2 = "mhlo.select"(%1, %arg1, %0) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + return %2 : tensor<*xf32> +} + +// CHECK-SCF-LABEL: @relu_grad +// CHECK-SCF-SAME: (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) +// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] +// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] +// CHECK-SCF-DAG: %[[S:.*]] = shape.any %[[S1]], %[[S0]] +// CHECK-SCF-DAG: %[[N:.*]] = shape.num_elements %[[S]] +// CHECK-SCF-DAG: %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] +// CHECK-SCF-DAG: %[[FLAT0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]]) +// CHECK-SCF-DAG: %[[FLAT1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]]) +// CHECK-SCF-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%[[FLAT0]]) {value = 0.0{{.*}}+00 : f32} +// CHECK-SCF-DAG: %[[PRED:.*]] = "mhlo.compare"(%[[FLAT0]], %[[ZERO]]) {comparison_direction = "GT"} +// CHECK-SCF: %[[UNSHAPED_RES:.*]] = "mhlo.select"(%[[PRED]], %[[FLAT1]], %[[ZERO]]) +// CHECK-SCF-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]] +// CHECK-SCF-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]] +// CHECK-SCF-DAG: %[[RES_SHAPE:.*]] = shape.broadcast %[[S1]], %[[S0]] +// CHECK-SCF-DAG: %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) +// CHECK-SCF: return %[[RES]]