From 9a8c254526412bb958e895d31c2da1e43fc7afd2 Mon Sep 17 00:00:00 2001
From: Adrian Kuegel
Date: Tue, 8 Jun 2021 04:22:15 -0700
Subject: [PATCH] Support complex types for Sinh.

Because mhlo::ConstantLike doesn't support complex types, we need to use
GetScalarOfType and broadcast it to the needed shape.

Disable the tf2xla fallback, now that MLIR fully supports Sinh.

PiperOrigin-RevId: 378123151
---
 .../mhlo/transforms/chlo_legalize_to_hlo.cc | 24 +++++++++++----
 tests/chlo_legalize_to_mhlo.mlir            | 29 +++++++++++++++++--
 2 files changed, 45 insertions(+), 8 deletions(-)

diff --git a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
index c62cbbd..bd0ec23 100644
--- a/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
+++ b/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc
@@ -994,9 +994,22 @@ Value MaterializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter,
                                             Location loc, ValueRange operands) {
   SinhOp::Adaptor transformed(operands);
   Value x = transformed.operand();
+  auto result_ty = x.getType().cast<ShapedType>();
 
-  Value log_one_half =
-      rewriter.create<mhlo::LogOp>(loc, getConstantLike(rewriter, loc, 0.5, x));
+  // TODO(b/190374484): Use mhlo::ConstantLikeOp when it supports complex types.
+  Value two = rewriter.create<mhlo::ConstOp>(
+      loc, hlo::GetScalarOfType(getElementTypeOrSelf(x.getType()), 2));
+  Type extent_tensor_type = shape::getExtentTensorType(x.getContext());
+  Value uncasted_shape =
+      rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, x);
+  Type shape_ty =
+      RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType());
+  Value shape = rewriter.create<tensor::CastOp>(loc, shape_ty, uncasted_shape);
+  Value two_with_x_shape = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
+      loc, result_ty, two, shape, rewriter.getI64TensorAttr({}));
+
+  Value log_two = rewriter.create<mhlo::LogOp>(loc, two_with_x_shape);
+  Value log_one_half = rewriter.create<mhlo::NegOp>(loc, log_two);
   Value exp_add = rewriter.create<mhlo::ExpOp>(
       loc, rewriter.create<mhlo::AddOp>(loc, x, log_one_half));
   Value exp_sub = rewriter.create<mhlo::ExpOp>(
@@ -1039,10 +1052,9 @@ struct ConvertSinhOp : public OpConversionPattern<SinhOp> {
     SinhOp::Adaptor transformed(operands);
     Value x = transformed.operand();
     if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) {
-      // TODO(hinsu): Support operands with complex element types by always
-      // using the formula for large x. The compare op is not legal for complex
-      // numbers.
-      return failure();
+      rewriter.replaceOp(op, MaterializeSinhApproximationForLargeX(
+                                 rewriter, op.getLoc(), operands));
+      return success();
     }
     rewriter.replaceOp(op,
                        MaterializeWithUpcast(rewriter, op.getLoc(), operands,
diff --git a/tests/chlo_legalize_to_mhlo.mlir b/tests/chlo_legalize_to_mhlo.mlir
index 50d985a..f94bb8d 100644
--- a/tests/chlo_legalize_to_mhlo.mlir
+++ b/tests/chlo_legalize_to_mhlo.mlir
@@ -2129,8 +2129,12 @@ func @polygamma_f16(%lhs : tensor<f16>, %rhs : tensor<f16>) -> tensor<f16> {
 // CHECK-LABEL: @sinh_f32
 // CHECK-SAME: (%[[X:.*]]: tensor<f32>)
 func @sinh_f32(%x : tensor<f32>) -> tensor<f32> {
-  // CHECK: %[[HALF:.*]] = mhlo.constant dense<5.000000e-01> : tensor<f32>
-  // CHECK: %[[LOG_HALF:.*]] = "mhlo.log"(%[[HALF]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[TWO:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[X]] : tensor<f32> -> tensor<?xindex>
+  // CHECK: %[[CASTED_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<?xindex> to tensor<0xindex>
+  // CHECK: %[[BROADCASTED_TWO:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TWO]], %[[CASTED_SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<0xindex>) -> tensor<f32>
+  // CHECK: %[[LOG_TWO:.*]] = "mhlo.log"(%[[BROADCASTED_TWO]]) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %[[LOG_HALF:.*]] = "mhlo.negate"(%[[LOG_TWO]]) : (tensor<f32>) -> tensor<f32>
   // CHECK: %[[X_PLUS_LOG_HALF:.*]] = mhlo.add %[[X]], %[[LOG_HALF]] : tensor<f32>
   // CHECK: %[[EXP_1:.*]] = "mhlo.exponential"(%[[X_PLUS_LOG_HALF]]) : (tensor<f32>) -> tensor<f32>
   // CHECK: %[[LOG_HALF_MINUS_X:.*]] = mhlo.subtract %[[LOG_HALF]], %[[X]] : tensor<f32>
@@ -2162,3 +2166,24 @@ func @sinh_f16(%x : tensor<f16>) -> tensor<f16> {
   %1 = chlo.sinh %x : tensor<f16> -> tensor<f16>
   return %1 : tensor<f16>
 }
+
+// -----
+
+// CHECK-LABEL: @sinh_complex
+// CHECK-SAME: (%[[X:.*]]: tensor<2xcomplex<f32>>)
+func @sinh_complex(%x : tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
+  // CHECK: %[[TWO:.*]] = mhlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<complex<f32>>
+  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[X]] : tensor<2xcomplex<f32>> -> tensor<?xindex>
+  // CHECK: %[[CASTED_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<?xindex> to tensor<1xindex>
+  // CHECK: %[[BROADCASTED_TWO:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TWO]], %[[CASTED_SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<complex<f32>>, tensor<1xindex>) -> tensor<2xcomplex<f32>>
+  // CHECK: %[[LOG_TWO:.*]] = "mhlo.log"(%[[BROADCASTED_TWO]]) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
+  // CHECK: %[[LOG_HALF:.*]] = "mhlo.negate"(%[[LOG_TWO]]) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
+  // CHECK: %[[X_PLUS_LOG_HALF:.*]] = mhlo.add %[[X]], %[[LOG_HALF]] : tensor<2xcomplex<f32>>
+  // CHECK: %[[EXP_1:.*]] = "mhlo.exponential"(%[[X_PLUS_LOG_HALF]]) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
+  // CHECK: %[[LOG_HALF_MINUS_X:.*]] = mhlo.subtract %[[LOG_HALF]], %[[X]] : tensor<2xcomplex<f32>>
+  // CHECK: %[[EXP_2:.*]] = "mhlo.exponential"(%[[LOG_HALF_MINUS_X]]) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
+  // CHECK: %[[RESULT:.*]] = mhlo.subtract %[[EXP_1]], %[[EXP_2]] : tensor<2xcomplex<f32>>
+  // CHECK: return %[[RESULT]] : tensor<2xcomplex<f32>>
+  %1 = chlo.sinh %x : tensor<2xcomplex<f32>> -> tensor<2xcomplex<f32>>
+  return %1 : tensor<2xcomplex<f32>>
+}
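The lowering above relies on the exact identity sinh(x) = e^(x + log(1/2)) - e^(log(1/2) - x), which holds for complex arguments just as for real ones; that is why the complex path can reuse MaterializeSinhApproximationForLargeX unconditionally, while the small-|x| branch is skipped because, as the removed TODO notes, the compare op is not legal for complex numbers. Below is a minimal standalone C++ sketch of that identity; it is illustrative only, not part of the patch, and the helper name SinhViaLargeXFormula is hypothetical.

    #include <cassert>
    #include <cmath>
    #include <complex>
    #include <cstdio>

    // Evaluates sinh via the "large |x|" formulation used by the lowering:
    // exp(x + log(1/2)) - exp(log(1/2) - x) == (e^x - e^-x) / 2 == sinh(x).
    template <typename T>
    T SinhViaLargeXFormula(T x) {
      T log_one_half = -std::log(T(2));
      return std::exp(x + log_one_half) - std::exp(log_one_half - x);
    }

    int main() {
      // Real input: agrees with std::sinh.
      double r = SinhViaLargeXFormula(1.5);
      assert(std::abs(r - std::sinh(1.5)) < 1e-12);

      // Complex input: the same formula applies, so no compare-based
      // small-|x| branch is needed for complex element types.
      std::complex<double> z(1.5, 0.75);
      std::complex<double> c = SinhViaLargeXFormula(z);
      assert(std::abs(c - std::sinh(z)) < 1e-12);

      std::printf("real: %f, complex: (%f, %f)\n", r, c.real(), c.imag());
      return 0;
    }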