diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 8f1fa95..a311ae7 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -539,22 +539,14 @@ inline Value MapLhloOpToStdScalarOp(Location loc, Type element_type = getElementTypeOrSelf(args.front().getType()); if (auto float_type = element_type.dyn_cast()) { bool ignored; - APFloat zero_apfloat(0.0f); - zero_apfloat.convert(float_type.getFloatSemantics(), - APFloat::rmNearestTiesToEven, &ignored); - Value zero = - b->create(loc, zero_apfloat, float_type); + APFloat one_apfloat(1.0f); + one_apfloat.convert(float_type.getFloatSemantics(), + APFloat::rmNearestTiesToEven, &ignored); + Value one = b->create(loc, one_apfloat, float_type); if (VectorType vec_type = args.front().getType().dyn_cast()) { - zero = b->create<::mlir::SplatOp>(loc, vec_type, zero); + one = b->create<::mlir::SplatOp>(loc, vec_type, one); } - Value ne0_i1 = - b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, args[0], zero); - Value ne0_float = b->create<::mlir::UIToFPOp>(loc, ne0_i1, float_type); - Value copy_sign = - b->create<::mlir::CopySignOp>(loc, result_types, ne0_float, args[0]); - auto is_nan = - b->create<::mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[0]); - return b->create<::mlir::SelectOp>(loc, is_nan, args[0], copy_sign); + return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]); } else if (auto integer_type = element_type.dyn_cast()) { // sign(x) = x == 0 ? 0 : ((x s>> 31) | 1) Value zero = diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 3dba087..5bfde29 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -594,12 +594,8 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): -// CHECK-NEXT: %[[CST_0:.*]] = constant 0.000000e+00 : f32 -// CHECK-NEXT: %[[NE_0:.*]] = cmpf "one", %[[OPERAND_IN]], %[[CST_0]] : f32 -// CHECK-NEXT: %[[NE_0_FLOAT:.*]] = uitofp %[[NE_0]] : i1 to f32 -// CHECK-NEXT: %[[SIGN:.*]] = copysign %[[NE_0_FLOAT]], %[[OPERAND_IN]] : f32 -// CHECK-NEXT: %[[CMP:.*]] = cmpf "uno", %[[OPERAND_IN]], %[[OPERAND_IN]] : f32 -// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[OPERAND_IN]], %[[SIGN]] : f32 +// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : f32 +// CHECK-NEXT: %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32 // ----- @@ -611,12 +607,8 @@ func @sign_bf16(%input: memref<2x2xbf16>, %result: memref<2x2xbf16>) { } // CHECK: linalg.generic // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: bf16, %[[RESULT_OUT:.*]]): -// CHECK-NEXT: %[[CST_0:.*]] = constant 0.000000e+00 : bf16 -// CHECK-NEXT: %[[NE_0:.*]] = cmpf "one", %[[OPERAND_IN]], %[[CST_0]] : bf16 -// CHECK-NEXT: %[[NE_0_FLOAT:.*]] = uitofp %[[NE_0]] : i1 to bf16 -// CHECK-NEXT: %[[SIGN:.*]] = copysign %[[NE_0_FLOAT]], %[[OPERAND_IN]] : bf16 -// CHECK-NEXT: %[[CMP:.*]] = cmpf "uno", %[[OPERAND_IN]], %[[OPERAND_IN]] : bf16 -// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[OPERAND_IN]], %[[SIGN]] : bf16 +// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : bf16 +// CHECK-NEXT: %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : bf16 // CHECK-NEXT: linalg.yield %[[RESULT]] : bf16 // -----