Support complex types for Sinh.
Because mhlo::ConstantLike doesn't support complex types, we need to use GetScalarOfType and broadcast it to the needed shape. Disable the tf2xla fallback, now that MLIR fully supports Sinh. PiperOrigin-RevId: 378123151
This commit is contained in:
		
							parent
							
								
									c47869f931
								
							
						
					
					
						commit
						9a8c254526
					
				| 
						 | 
					@ -994,9 +994,22 @@ Value MaterializeSinhApproximationForLargeX(ConversionPatternRewriter &rewriter,
 | 
				
			||||||
                                            Location loc, ValueRange operands) {
 | 
					                                            Location loc, ValueRange operands) {
 | 
				
			||||||
  SinhOp::Adaptor transformed(operands);
 | 
					  SinhOp::Adaptor transformed(operands);
 | 
				
			||||||
  Value x = transformed.operand();
 | 
					  Value x = transformed.operand();
 | 
				
			||||||
 | 
					  auto result_ty = x.getType().cast<ShapedType>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  Value log_one_half =
 | 
					  // TODO(b/190374484): Use mhlo::ConstantLikeOp when it supports complex types.
 | 
				
			||||||
      rewriter.create<mhlo::LogOp>(loc, getConstantLike(rewriter, loc, 0.5, x));
 | 
					  Value two = rewriter.create<mhlo::ConstOp>(
 | 
				
			||||||
 | 
					      loc, hlo::GetScalarOfType(getElementTypeOrSelf(x.getType()), 2));
 | 
				
			||||||
 | 
					  Type extent_tensor_type = shape::getExtentTensorType(x.getContext());
 | 
				
			||||||
 | 
					  Value uncasted_shape =
 | 
				
			||||||
 | 
					      rewriter.create<shape::ShapeOfOp>(loc, extent_tensor_type, x);
 | 
				
			||||||
 | 
					  Type shape_ty =
 | 
				
			||||||
 | 
					      RankedTensorType::get({result_ty.getRank()}, rewriter.getIndexType());
 | 
				
			||||||
 | 
					  Value shape = rewriter.create<tensor::CastOp>(loc, shape_ty, uncasted_shape);
 | 
				
			||||||
 | 
					  Value two_with_x_shape = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
 | 
				
			||||||
 | 
					      loc, result_ty, two, shape, rewriter.getI64TensorAttr({}));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  Value log_two = rewriter.create<mhlo::LogOp>(loc, two_with_x_shape);
 | 
				
			||||||
 | 
					  Value log_one_half = rewriter.create<mhlo::NegOp>(loc, log_two);
 | 
				
			||||||
  Value exp_add = rewriter.create<mhlo::ExpOp>(
 | 
					  Value exp_add = rewriter.create<mhlo::ExpOp>(
 | 
				
			||||||
      loc, rewriter.create<mhlo::AddOp>(loc, x, log_one_half));
 | 
					      loc, rewriter.create<mhlo::AddOp>(loc, x, log_one_half));
 | 
				
			||||||
  Value exp_sub = rewriter.create<mhlo::ExpOp>(
 | 
					  Value exp_sub = rewriter.create<mhlo::ExpOp>(
 | 
				
			||||||
| 
						 | 
					@ -1039,10 +1052,9 @@ struct ConvertSinhOp : public OpConversionPattern<SinhOp> {
 | 
				
			||||||
    SinhOp::Adaptor transformed(operands);
 | 
					    SinhOp::Adaptor transformed(operands);
 | 
				
			||||||
    Value x = transformed.operand();
 | 
					    Value x = transformed.operand();
 | 
				
			||||||
    if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) {
 | 
					    if (x.getType().cast<ShapedType>().getElementType().isa<ComplexType>()) {
 | 
				
			||||||
      // TODO(hinsu): Support operands with complex element types by always
 | 
					      rewriter.replaceOp(op, MaterializeSinhApproximationForLargeX(
 | 
				
			||||||
      // using the formula for large x. The compare op is not legal for complex
 | 
					                                 rewriter, op.getLoc(), operands));
 | 
				
			||||||
      // numbers.
 | 
					      return success();
 | 
				
			||||||
      return failure();
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    rewriter.replaceOp(op,
 | 
					    rewriter.replaceOp(op,
 | 
				
			||||||
                       MaterializeWithUpcast(rewriter, op.getLoc(), operands,
 | 
					                       MaterializeWithUpcast(rewriter, op.getLoc(), operands,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -2129,8 +2129,12 @@ func @polygamma_f16(%lhs : tensor<f16>, %rhs : tensor<f16>) -> tensor<f16> {
 | 
				
			||||||
// CHECK-LABEL: @sinh_f32
 | 
					// CHECK-LABEL: @sinh_f32
 | 
				
			||||||
// CHECK-SAME: (%[[X:.*]]: tensor<f32>)
 | 
					// CHECK-SAME: (%[[X:.*]]: tensor<f32>)
 | 
				
			||||||
func @sinh_f32(%x : tensor<f32>) -> tensor<f32> {
 | 
					func @sinh_f32(%x : tensor<f32>) -> tensor<f32> {
 | 
				
			||||||
  // CHECK: %[[HALF:.*]] = mhlo.constant dense<5.000000e-01> : tensor<f32>
 | 
					  // CHECK: %[[TWO:.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
 | 
				
			||||||
  // CHECK: %[[LOG_HALF:.*]] = "mhlo.log"(%[[HALF]]) : (tensor<f32>) -> tensor<f32>
 | 
					  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[X]] : tensor<f32> -> tensor<?xindex>
 | 
				
			||||||
 | 
					  // CHECK: %[[CASTED_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<?xindex> to tensor<0xindex>
 | 
				
			||||||
 | 
					  // CHECK: %[[BROADCASTED_TWO:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TWO]], %[[CASTED_SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<0xindex>) -> tensor<f32>
 | 
				
			||||||
 | 
					  // CHECK: %[[LOG_TWO:.*]] = "mhlo.log"(%[[BROADCASTED_TWO]]) : (tensor<f32>) -> tensor<f32>
 | 
				
			||||||
 | 
					  // CHECK: %[[LOG_HALF:.*]] = "mhlo.negate"(%[[LOG_TWO]]) : (tensor<f32>) -> tensor<f32>
 | 
				
			||||||
  // CHECK: %[[X_PLUS_LOG_HALF:.*]] = mhlo.add %[[X]], %[[LOG_HALF]] : tensor<f32>
 | 
					  // CHECK: %[[X_PLUS_LOG_HALF:.*]] = mhlo.add %[[X]], %[[LOG_HALF]] : tensor<f32>
 | 
				
			||||||
  // CHECK: %[[EXP_1:.*]] = "mhlo.exponential"(%[[X_PLUS_LOG_HALF]]) : (tensor<f32>) -> tensor<f32>
 | 
					  // CHECK: %[[EXP_1:.*]] = "mhlo.exponential"(%[[X_PLUS_LOG_HALF]]) : (tensor<f32>) -> tensor<f32>
 | 
				
			||||||
  // CHECK: %[[LOG_HALF_MINUS_X:.*]] = mhlo.subtract %[[LOG_HALF]], %[[X]] : tensor<f32>
 | 
					  // CHECK: %[[LOG_HALF_MINUS_X:.*]] = mhlo.subtract %[[LOG_HALF]], %[[X]] : tensor<f32>
 | 
				
			||||||
| 
						 | 
					@ -2162,3 +2166,24 @@ func @sinh_f16(%x : tensor<f16>) -> tensor<f16> {
 | 
				
			||||||
  %1 = chlo.sinh %x : tensor<f16> -> tensor<f16>
 | 
					  %1 = chlo.sinh %x : tensor<f16> -> tensor<f16>
 | 
				
			||||||
  return %1 : tensor<f16>
 | 
					  return %1 : tensor<f16>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CHECK-LABEL: @sinh_complex
 | 
				
			||||||
 | 
					// CHECK-SAME: (%[[X:.*]]: tensor<2xcomplex<f32>>)
 | 
				
			||||||
 | 
					func @sinh_complex(%x : tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
 | 
				
			||||||
 | 
					  // CHECK: %[[TWO:.*]] = mhlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<complex<f32>>
 | 
				
			||||||
 | 
					  // CHECK: %[[SHAPE:.*]] = shape.shape_of %[[X]] : tensor<2xcomplex<f32>> -> tensor<?xindex>
 | 
				
			||||||
 | 
					  // CHECK: %[[CASTED_SHAPE:.*]] = tensor.cast %[[SHAPE]] : tensor<?xindex> to tensor<1xindex>
 | 
				
			||||||
 | 
					  // CHECK: %[[BROADCASTED_TWO:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[TWO]], %[[CASTED_SHAPE]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<complex<f32>>, tensor<1xindex>) -> tensor<2xcomplex<f32>>
 | 
				
			||||||
 | 
					  // CHECK: %[[LOG_TWO:.*]] = "mhlo.log"(%[[BROADCASTED_TWO]]) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
 | 
				
			||||||
 | 
					  // CHECK: %[[LOG_HALF:.*]] = "mhlo.negate"(%[[LOG_TWO]]) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
 | 
				
			||||||
 | 
					  // CHECK: %[[X_PLUS_LOG_HALF:.*]] = mhlo.add %[[X]], %[[LOG_HALF]] : tensor<2xcomplex<f32>>
 | 
				
			||||||
 | 
					  // CHECK: %[[EXP_1:.*]] = "mhlo.exponential"(%[[X_PLUS_LOG_HALF]]) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
 | 
				
			||||||
 | 
					  // CHECK: %[[LOG_HALF_MINUS_X:.*]] = mhlo.subtract %[[LOG_HALF]], %[[X]] : tensor<2xcomplex<f32>>
 | 
				
			||||||
 | 
					  // CHECK: %[[EXP_2:.*]] = "mhlo.exponential"(%[[LOG_HALF_MINUS_X]]) : (tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>>
 | 
				
			||||||
 | 
					  // CHECK: %[[RESULT:.*]] = mhlo.subtract %[[EXP_1]], %[[EXP_2]] : tensor<2xcomplex<f32>>
 | 
				
			||||||
 | 
					  // CHECK: return %[[RESULT]] : tensor<2xcomplex<f32>>
 | 
				
			||||||
 | 
					  %1 = chlo.sinh %x : tensor<2xcomplex<f32>> -> tensor<2xcomplex<f32>>
 | 
				
			||||||
 | 
					  return %1 : tensor<2xcomplex<f32>>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue