[MLIR][KernelGen] Better rank specialization for clusters containing `mhlo.select`
Merge `mhlo.select` into rank specialization clusters. Infer shape equalities correctly from `mhlo.select` (and also from `mhlo.clamp`). This allows to lower the relu kernel completely flattened. PiperOrigin-RevId: 379925793
This commit is contained in:
		
							parent
							
								
									376da8592f
								
							
						
					
					
						commit
						9f47ff607b
					
				|  | @ -68,8 +68,7 @@ bool IsClusterable(Operation *op) { | ||||||
|   if (op->getNumOperands() == 0) return false; |   if (op->getNumOperands() == 0) return false; | ||||||
|   return (op->hasTrait<mlir::OpTrait::Elementwise>() && |   return (op->hasTrait<mlir::OpTrait::Elementwise>() && | ||||||
|           op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) || |           op->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>()) || | ||||||
|          (op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>() && |          op->hasTrait<mhlo::OpTrait::BroadcastingElementwise>(); | ||||||
|           op->hasTrait<chlo::OpTrait::Broadcasting>()); |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| struct RankSpecializationClusterPattern : public RewritePattern { | struct RankSpecializationClusterPattern : public RewritePattern { | ||||||
|  | @ -714,6 +713,8 @@ Value MaterializeDefaultRankSpecialization(OpBuilder &b, Location loc, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // This is a very limited form of shape inference. It is correct but incomplete.
 | // This is a very limited form of shape inference. It is correct but incomplete.
 | ||||||
|  | // TODO(frgossen): Infer shape equalities from surrounding shape constraints
 | ||||||
|  | // when these are generated.
 | ||||||
| SmallVector<SmallVector<Value, 4>, 4> FindNonScalarShapeEquivalences( | SmallVector<SmallVector<Value, 4>, 4> FindNonScalarShapeEquivalences( | ||||||
|     chlo::RankSpecializationClusterOp op) { |     chlo::RankSpecializationClusterOp op) { | ||||||
|   llvm::EquivalenceClasses<Value> eqs; |   llvm::EquivalenceClasses<Value> eqs; | ||||||
|  | @ -735,12 +736,17 @@ SmallVector<SmallVector<Value, 4>, 4> FindNonScalarShapeEquivalences( | ||||||
|       if (!nested_op.getOperands().empty() && !nested_op.getResults().empty()) |       if (!nested_op.getOperands().empty() && !nested_op.getResults().empty()) | ||||||
|         eqs.unionSets(nested_op.getResult(0), nested_op.getOperand(0)); |         eqs.unionSets(nested_op.getResult(0), nested_op.getOperand(0)); | ||||||
|     } |     } | ||||||
|     // TODO(frgossen): Replace this with a check for the appropriate trait when
 |   } | ||||||
|     // that is available.
 | 
 | ||||||
|  |   // Find equalities through special knowledge of ops.
 | ||||||
|  |   for (Operation &nested_op : op.getBody()->without_terminator()) { | ||||||
|     if (auto select_op = llvm::dyn_cast<mhlo::SelectOp>(nested_op)) { |     if (auto select_op = llvm::dyn_cast<mhlo::SelectOp>(nested_op)) { | ||||||
|       union_sets( |       union_sets( | ||||||
|           {select_op.on_true(), select_op.on_false(), select_op.getResult()}); |           {select_op.on_true(), select_op.on_false(), select_op.getResult()}); | ||||||
|     } |     } | ||||||
|  |     if (auto clamp_op = llvm::dyn_cast<mhlo::ClampOp>(nested_op)) { | ||||||
|  |       union_sets({clamp_op.operand(), clamp_op.getResult()}); | ||||||
|  |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   // Convert to a list-like equivalence class representation.
 |   // Convert to a list-like equivalence class representation.
 | ||||||
|  |  | ||||||
|  | @ -615,3 +615,40 @@ func @all_equal_shapes_inferrable(%arg0: tensor<*xf64>, %arg1 : tensor<*xf64>) | ||||||
| // CHECK-SCF-DAG:   %[[RES_S:.*]] = shape.broadcast %8, %9 | // CHECK-SCF-DAG:   %[[RES_S:.*]] = shape.broadcast %8, %9 | ||||||
| // CHECK-SCF-DAG:   %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[RES_S]]) | // CHECK-SCF-DAG:   %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RES]], %[[RES_S]]) | ||||||
| // CHECK-SCF:       return %[[RES]] | // CHECK-SCF:       return %[[RES]] | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | // All shapes are equal, which is inferrable through the select op. | ||||||
|  | // CHECK-LABEL: @relu_grad | ||||||
|  | // CHECK-SAME:  (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) | ||||||
|  | func @relu_grad(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { | ||||||
|  |   // CHECK: %[[RES:.*]] = "chlo.rank_specialization_cluster"(%[[ARG1]], %[[ARG0]]) | ||||||
|  |   // CHECK: ^bb0(%[[ARG1_:.*]]: tensor<*xf32>, %[[ARG0_:.*]]: tensor<*xf32>) | ||||||
|  |   // CHECK:   %[[TMP0:.*]] = "chlo.constant_like"(%[[ARG0_]]) {value = 0.0{{.*}}e+00 : f32} | ||||||
|  |   // CHECK:   %[[TMP1:.*]] = "mhlo.compare"(%[[ARG0_]], %[[TMP0]]) {comparison_direction = "GT"} | ||||||
|  |   // CHECK:   %[[TMP2:.*]] = "mhlo.select"(%[[TMP1]], %[[ARG1_]], %[[TMP0]]) | ||||||
|  |   // CHECK:   "chlo.rank_specialization_cluster_yield"(%[[TMP2]]) | ||||||
|  |   // CHECK: return %[[RES]] | ||||||
|  |   %0 = "chlo.constant_like"(%arg0) {value = 0.000000e+00 : f32} : (tensor<*xf32>) -> tensor<*xf32> | ||||||
|  |   %1 = "mhlo.compare"(%arg0, %0) {comparison_direction = "GT"} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xi1> | ||||||
|  |   %2 = "mhlo.select"(%1, %arg1, %0) : (tensor<*xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> | ||||||
|  |   return %2 : tensor<*xf32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // CHECK-SCF-LABEL: @relu_grad | ||||||
|  | // CHECK-SCF-SAME:  (%[[ARG0:.*]]: tensor<*xf32>, %[[ARG1:.*]]: tensor<*xf32>) | ||||||
|  | // CHECK-SCF-DAG:   %[[S0:.*]] = shape.shape_of %[[ARG0]] | ||||||
|  | // CHECK-SCF-DAG:   %[[S1:.*]] = shape.shape_of %[[ARG1]] | ||||||
|  | // CHECK-SCF-DAG:   %[[S:.*]] = shape.any %[[S1]], %[[S0]] | ||||||
|  | // CHECK-SCF-DAG:   %[[N:.*]] = shape.num_elements %[[S]] | ||||||
|  | // CHECK-SCF-DAG:   %[[FLAT_SHAPE:.*]] = tensor.from_elements %[[N]] | ||||||
|  | // CHECK-SCF-DAG:   %[[FLAT0:.*]] = "mhlo.dynamic_reshape"(%[[ARG0]], %[[FLAT_SHAPE]]) | ||||||
|  | // CHECK-SCF-DAG:   %[[FLAT1:.*]] = "mhlo.dynamic_reshape"(%[[ARG1]], %[[FLAT_SHAPE]]) | ||||||
|  | // CHECK-SCF-DAG:   %[[ZERO:.*]] = "chlo.constant_like"(%[[FLAT0]]) {value = 0.0{{.*}}+00 : f32} | ||||||
|  | // CHECK-SCF-DAG:   %[[PRED:.*]] = "mhlo.compare"(%[[FLAT0]], %[[ZERO]]) {comparison_direction = "GT"} | ||||||
|  | // CHECK-SCF:       %[[UNSHAPED_RES:.*]] = "mhlo.select"(%[[PRED]], %[[FLAT1]], %[[ZERO]]) | ||||||
|  | // CHECK-SCF-DAG:   %[[S0:.*]] = shape.shape_of %[[ARG0]] | ||||||
|  | // CHECK-SCF-DAG:   %[[S1:.*]] = shape.shape_of %[[ARG1]] | ||||||
|  | // CHECK-SCF-DAG:   %[[RES_SHAPE:.*]] = shape.broadcast %[[S1]], %[[S0]] | ||||||
|  | // CHECK-SCF-DAG:   %[[RES:.*]] = "mhlo.dynamic_reshape"(%[[UNSHAPED_RES]], %[[RES_SHAPE]]) | ||||||
|  | // CHECK-SCF:       return %[[RES]] | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue