[MLIR] Lower `chlo.constant_like` to MHLO
Lower `chlo.constant_like` to a constant and, if needed, a broadcast. PiperOrigin-RevId: 331964137
This commit is contained in:
		
							parent
							
								
									da43c8596b
								
							
						
					
					
						commit
						a6fdebdc6c
					
				|  | @ -13,6 +13,8 @@ See the License for the specific language governing permissions and | |||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include <numeric> | ||||
| 
 | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" | ||||
| #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h" | ||||
|  | @ -31,6 +33,39 @@ namespace mlir { | |||
| namespace chlo { | ||||
| namespace { | ||||
| 
 | ||||
| struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> { | ||||
|   using OpConversionPattern<ConstantLikeOp>::OpConversionPattern; | ||||
|   LogicalResult matchAndRewrite( | ||||
|       ConstantLikeOp op, ArrayRef<Value> operands, | ||||
|       ConversionPatternRewriter &rewriter) const override { | ||||
|     auto result_ty = op.getType().cast<ShapedType>(); | ||||
| 
 | ||||
|     // Unranked uses are not supported.  Consider `transform-unranked-hlo`.
 | ||||
|     if (!result_ty.hasRank()) return failure(); | ||||
| 
 | ||||
|     // Lower to MHLO constant if statically shaped.
 | ||||
|     if (result_ty.hasStaticShape()) { | ||||
|       rewriter.replaceOpWithNewOp<mhlo::ConstOp>( | ||||
|           op, DenseElementsAttr::get(result_ty, op.value())); | ||||
|       return success(); | ||||
|     } | ||||
| 
 | ||||
|     // Lower to broadcasted constant.
 | ||||
|     ConstantLikeOp::Adaptor transformed(operands); | ||||
|     auto loc = op.getLoc(); | ||||
|     Type extent_tensor_type = shape::getExtentTensorType(op.getContext()); | ||||
|     Value constant = rewriter.create<mhlo::ConstOp>(loc, op.value()); | ||||
|     Value uncasted_shape = rewriter.create<shape::ShapeOfOp>( | ||||
|         loc, extent_tensor_type, transformed.operand()); | ||||
|     Type shape_ty = | ||||
|         RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType()); | ||||
|     Value shape = rewriter.create<TensorCastOp>(loc, shape_ty, uncasted_shape); | ||||
|     rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>( | ||||
|         op, result_ty, constant, shape, rewriter.getI64TensorAttr({})); | ||||
|     return success(); | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| // Converts binary ops that statically are determined to not broadcast directly
 | ||||
| // to the corresponding mhlo non-broadcasting op.
 | ||||
| template <typename ChloOpTy, typename HloOpTy, typename Adaptor> | ||||
|  | @ -505,6 +540,9 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context, | |||
|       context, patterns); | ||||
|   PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>( | ||||
|       context, patterns); | ||||
| 
 | ||||
|   // Other patterns.
 | ||||
|   patterns->insert<ConvertConstantLikeOp>(context); | ||||
| } | ||||
| 
 | ||||
| }  // namespace chlo
 | ||||
|  |  | |||
|  | @ -0,0 +1,26 @@ | |||
| // RUN: mlir-hlo-opt --mhlo-test-chlo-legalize-to-hlo --split-input-file %s | FileCheck %s | ||||
| 
 | ||||
| // Lower statically shaped `constant_like` to constant. | ||||
| // CHECK-LABEL: @constant_like_static_shape | ||||
| func @constant_like_static_shape(%arg : tensor<1x2xi64>) -> tensor<1x2xf32> { | ||||
|   // CHECK: %[[RESULT:.*]] = mhlo.constant dense<3.200000e+00> : tensor<1x2xf32> | ||||
|   // CHECK: return %[[RESULT]] | ||||
|   %result = "chlo.constant_like"(%arg) { value = 3.2 : f32 } | ||||
|       : (tensor<1x2xi64>) -> tensor<1x2xf32> | ||||
|   return %result : tensor<1x2xf32> | ||||
| } | ||||
| 
 | ||||
| // Lower dynamically shaped `constant_like` to broadcasted constant. | ||||
| // CHECK-LABEL: constant_like_dynamic_shape | ||||
| // CHECK-SAME: (%[[ARG:.*]]: tensor<?x?xi64>) | ||||
| func @constant_like_dynamic_shape(%arg : tensor<?x?xi64>) -> tensor<?x?xf32> { | ||||
|   // CHECK: %[[CONSTANT:.*]] = mhlo.constant dense<3.200000e+00> : tensor<f32> | ||||
|   // CHECK: %[[UNCASTED_SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<?x?xi64> -> tensor<?xindex> | ||||
|   // CHECK: %[[SHAPE:.*]] = tensor_cast %[[UNCASTED_SHAPE]] : tensor<?xindex> to tensor<2xindex> | ||||
|   // CHECK: %[[BROADCASTED_CONSTANT:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[CONSTANT]], %[[SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32> | ||||
|   // CHECK: return %[[BROADCASTED_CONSTANT]] : tensor<?x?xf32> | ||||
|   %result = "chlo.constant_like"(%arg) { value = 3.2 : f32 } | ||||
|       : (tensor<?x?xi64>) -> tensor<?x?xf32> | ||||
|   return %result : tensor<?x?xf32> | ||||
| } | ||||
| 
 | ||||
		Loading…
	
		Reference in New Issue