From 8672735e9a51166a42b898530ef4aa3611be3f59 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 15 Feb 2021 04:35:43 -0800 Subject: [PATCH] [mhlo] Lower float->bool to a comparison with zero This matches what TF (and C++) do in this case. PiperOrigin-RevId: 357553098 --- .../mhlo/transforms/map_lmhlo_to_scalar_op.h | 7 +++++++ tests/hlo-legalize-to-linalg.mlir | 14 ++++++++++++++ tests/lhlo-legalize-to-linalg.mlir | 14 ++++++++++++++ 3 files changed, 35 insertions(+) diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 7aa956f..bcef846 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -323,6 +323,13 @@ inline Value MapLhloOpToStdScalarOp( // No conversion is needed for the same width integers return args.front(); } + if (targetType.isInteger(/*width=*/1)) { + Value zero = b->create(loc, b->getFloatAttr(sourceType, 0.0)); + if (VectorType vec_type = args.front().getType().dyn_cast()) { + zero = b->create<::mlir::SplatOp>(loc, vec_type, zero); + } + return b->create(loc, CmpFPredicate::UNE, args.front(), zero); + } if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) { return b->create(loc, result_types, args, mlir::None); } diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 59e7255..dda3ff5 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -727,6 +727,20 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @convert_f32_to_i1 +func @convert_f32_to_i1(%input: tensor<2x2xf32>) -> tensor<2x2xi1> { + %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi1> + return %result : tensor<2x2xi1> +} +// CHECK: linalg.init_tensor +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %{{.*}}: i1): +// CHECK-NEXT: %[[ZERO:.*]] = constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[RESULT:.*]] = cmpf une, %[[OPERAND_IN]], %[[ZERO]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + // CHECK-LABEL: func @convert_f32_to_i32 func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> { %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32> diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index e0ab3e2..39a6d72 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -502,6 +502,20 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +// CHECK-LABEL: func @convert_f32_to_i1 +func @convert_f32_to_i1(%input: memref<2x2xf32>, %result: memref<2x2xi1>) { + "lmhlo.convert"(%input, %result) + : (memref<2x2xf32>, memref<2x2xi1>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]: i1): +// CHECK-NEXT: %[[ZERO:.*]] = constant 0.000000e+00 : f32 +// CHECK-NEXT: %[[RESULT:.*]] = cmpf une, %[[OPERAND_IN]], %[[ZERO]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + // CHECK-LABEL: func @convert_f32_to_i32 func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) { "lmhlo.convert"(%input, %result)