Fix SignOp lowering for floating point values.

It didn't return 0 for 0.0 and -0.0.
Currently we emit -0.0 for -0.0 which is correct according to the HLO dialect.
For the TF_SignOp we should emit 0.0 in that case, we will leave that as a TODO.
Enable the tests which work now, and add another one for Int64.
Also improve the registration code, we should not register the Int32 kernel.

PiperOrigin-RevId: 347981124
This commit is contained in:
Adrian Kuegel 2020-12-17 01:44:54 -08:00 committed by TensorFlow MLIR Team
parent 099c130daf
commit 1f244c3e2c
2 changed files with 26 additions and 10 deletions

View File

@ -548,14 +548,22 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
Type element_type = getElementTypeOrSelf(args.front().getType()); Type element_type = getElementTypeOrSelf(args.front().getType());
if (auto float_type = element_type.dyn_cast<FloatType>()) { if (auto float_type = element_type.dyn_cast<FloatType>()) {
bool ignored; bool ignored;
APFloat one_apfloat(1.0f); APFloat zero_apfloat(0.0f);
one_apfloat.convert(float_type.getFloatSemantics(), zero_apfloat.convert(float_type.getFloatSemantics(),
APFloat::rmNearestTiesToEven, &ignored); APFloat::rmNearestTiesToEven, &ignored);
Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type); Value zero =
b->create<mlir::ConstantFloatOp>(loc, zero_apfloat, float_type);
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
one = b->create<::mlir::SplatOp>(loc, vec_type, one); zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
} }
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]); Value ne0_i1 =
b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, args[0], zero);
Value ne0_float = b->create<::mlir::UIToFPOp>(loc, ne0_i1, zero.getType());
Value copy_sign =
b->create<::mlir::CopySignOp>(loc, result_types, ne0_float, args[0]);
auto is_nan =
b->create<::mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[0]);
return b->create<::mlir::SelectOp>(loc, is_nan, args[0], copy_sign);
} else if (auto integer_type = element_type.dyn_cast<IntegerType>()) { } else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
// sign(x) = x == 0 ? 0 : ((x s>> 31) | 1) // sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)
Value zero = Value zero =

View File

@ -594,8 +594,12 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
} }
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : f32 // CHECK-NEXT: %[[CST_0:.*]] = constant 0.000000e+00 : f32
// CHECK-NEXT: %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : f32 // CHECK-NEXT: %[[NE_0:.*]] = cmpf "one", %[[OPERAND_IN]], %[[CST_0]] : f32
// CHECK-NEXT: %[[NE_0_FLOAT:.*]] = uitofp %[[NE_0]] : i1 to f32
// CHECK-NEXT: %[[SIGN:.*]] = copysign %[[NE_0_FLOAT]], %[[OPERAND_IN]] : f32
// CHECK-NEXT: %[[CMP:.*]] = cmpf "uno", %[[OPERAND_IN]], %[[OPERAND_IN]] : f32
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[OPERAND_IN]], %[[SIGN]] : f32
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32
// ----- // -----
@ -607,8 +611,12 @@ func @sign_bf16(%input: memref<2x2xbf16>, %result: memref<2x2xbf16>) {
} }
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: bf16, %[[RESULT_OUT:.*]]): // CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: bf16, %[[RESULT_OUT:.*]]):
// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : bf16 // CHECK-NEXT: %[[CST_0:.*]] = constant 0.000000e+00 : bf16
// CHECK-NEXT: %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : bf16 // CHECK-NEXT: %[[NE_0:.*]] = cmpf "one", %[[OPERAND_IN]], %[[CST_0]] : bf16
// CHECK-NEXT: %[[NE_0_FLOAT:.*]] = uitofp %[[NE_0]] : i1 to bf16
// CHECK-NEXT: %[[SIGN:.*]] = copysign %[[NE_0_FLOAT]], %[[OPERAND_IN]] : bf16
// CHECK-NEXT: %[[CMP:.*]] = cmpf "uno", %[[OPERAND_IN]], %[[OPERAND_IN]] : bf16
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[OPERAND_IN]], %[[SIGN]] : bf16
// CHECK-NEXT: linalg.yield %[[RESULT]] : bf16 // CHECK-NEXT: linalg.yield %[[RESULT]] : bf16
// ----- // -----