diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
index de2a99b..fc91789 100644
--- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
+++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#include <numeric>
+
 #include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
@@ -31,6 +33,42 @@ namespace mlir {
 namespace chlo {
 namespace {
 
+// Lowers `chlo.constant_like` to MHLO. A statically shaped result becomes a
+// plain `mhlo.constant`; a ranked dynamic result becomes a scalar constant
+// broadcast to the operand's runtime shape. Unranked results are rejected.
+struct ConvertConstantLikeOp : public OpConversionPattern<ConstantLikeOp> {
+  using OpConversionPattern<ConstantLikeOp>::OpConversionPattern;
+  LogicalResult matchAndRewrite(
+      ConstantLikeOp op, ArrayRef<Value> operands,
+      ConversionPatternRewriter &rewriter) const override {
+    auto result_ty = op.getType().cast<ShapedType>();
+
+    // Unranked uses are not supported. Consider `transform-unranked-hlo`.
+    if (!result_ty.hasRank()) return failure();
+
+    // Lower to MHLO constant if statically shaped.
+    if (result_ty.hasStaticShape()) {
+      rewriter.replaceOpWithNewOp<mhlo::ConstOp>(
+          op, DenseElementsAttr::get(result_ty, op.value()));
+      return success();
+    }
+
+    // Lower to broadcasted constant.
+    ConstantLikeOp::Adaptor transformed(operands);
+    auto loc = op.getLoc();
+    Type extent_tensor_type = shape::getExtentTensorType(op.getContext());
+    Value constant = rewriter.create<mhlo::ConstOp>(loc, op.value());
+    Value uncasted_shape = rewriter.create<shape::ShapeOfOp>(
+        loc, extent_tensor_type, transformed.operand());
+    Type shape_ty =
+        RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType());
+    Value shape = rewriter.create<TensorCastOp>(loc, shape_ty, uncasted_shape);
+    rewriter.replaceOpWithNewOp<mhlo::DynamicBroadcastInDimOp>(
+        op, result_ty, constant, shape, rewriter.getI64TensorAttr({}));
+    return success();
+  }
+};
+
 // Converts binary ops that statically are determined to not broadcast directly
 // to the corresponding mhlo non-broadcasting op.
 template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
@@ -505,6 +543,9 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
       context, patterns);
   PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>(
       context, patterns);
+
+  // Other patterns.
+  patterns->insert<ConvertConstantLikeOp>(context);
 }
 
 }  // namespace chlo
diff --git a/tests/chlo_legalize_to_mhlo.mlir b/tests/chlo_legalize_to_mhlo.mlir
new file mode 100644
index 0000000..371e730
--- /dev/null
+++ b/tests/chlo_legalize_to_mhlo.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-hlo-opt --mhlo-test-chlo-legalize-to-hlo --split-input-file %s | FileCheck %s
+
+// Lower statically shaped `constant_like` to constant.
+// CHECK-LABEL: @constant_like_static_shape
+func @constant_like_static_shape(%arg : tensor<1x2xi64>) -> tensor<1x2xf32> {
+  // CHECK: %[[RESULT:.*]] = mhlo.constant dense<3.200000e+00> : tensor<1x2xf32>
+  // CHECK: return %[[RESULT]]
+  %result = "chlo.constant_like"(%arg) { value = 3.2 : f32 }
+      : (tensor<1x2xi64>) -> tensor<1x2xf32>
+  return %result : tensor<1x2xf32>
+}
+
+// Lower dynamically shaped `constant_like` to broadcasted constant.
+// CHECK-LABEL: constant_like_dynamic_shape
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?xi64>)
+func @constant_like_dynamic_shape(%arg : tensor<?x?xi64>) -> tensor<?x?xf32> {
+  // CHECK: %[[CONSTANT:.*]] = mhlo.constant dense<3.200000e+00> : tensor<f32>
+  // CHECK: %[[UNCASTED_SHAPE:.*]] = shape.shape_of %[[ARG]] : tensor<?x?xi64> -> tensor<?xindex>
+  // CHECK: %[[SHAPE:.*]] = tensor_cast %[[UNCASTED_SHAPE]] : tensor<?xindex> to tensor<2xindex>
+  // CHECK: %[[BROADCASTED_CONSTANT:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[CONSTANT]], %[[SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<2xindex>) -> tensor<?x?xf32>
+  // CHECK: return %[[BROADCASTED_CONSTANT]] : tensor<?x?xf32>
+  %result = "chlo.constant_like"(%arg) { value = 3.2 : f32 }
+      : (tensor<?x?xi64>) -> tensor<?x?xf32>
+  return %result : tensor<?x?xf32>
+}
+