From 89d81adf6d2e95cc13c18aba98b0edd3322cac15 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 15 Feb 2021 03:11:12 -0800 Subject: [PATCH] [mhlo] Lower float->bool to a comparison with zero This matches what TF (and C++) do in this case. PiperOrigin-RevId: 357541594 --- .../mhlo/transforms/map_lmhlo_to_scalar_op.h | 4 ---- tests/hlo-legalize-to-linalg.mlir | 14 -------------- tests/lhlo-legalize-to-linalg.mlir | 14 -------------- 3 files changed, 32 deletions(-) diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index c9a12ab..7aa956f 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -323,10 +323,6 @@ inline Value MapLhloOpToStdScalarOp( // No conversion is needed for the same width integers return args.front(); } - if (targetType.isInteger(/*width=*/1)) { - auto zero = b->create(loc, b->getFloatAttr(sourceType, 0.0)); - return b->create(loc, CmpFPredicate::UNE, args.front(), zero); - } if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) { return b->create(loc, result_types, args, mlir::None); } diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index dda3ff5..59e7255 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -727,20 +727,6 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> { // ----- -// CHECK-LABEL: func @convert_f32_to_i1 -func @convert_f32_to_i1(%input: tensor<2x2xf32>) -> tensor<2x2xi1> { - %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi1> - return %result : tensor<2x2xi1> -} -// CHECK: linalg.init_tensor -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %{{.*}}: i1): -// CHECK-NEXT: %[[ZERO:.*]] = constant 0.000000e+00 : f32 -// CHECK-NEXT: %[[RESULT:.*]] = cmpf une, %[[OPERAND_IN]], %[[ZERO]] : f32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 - -// ----- - // CHECK-LABEL: func @convert_f32_to_i32 func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> { %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32> diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 39a6d72..e0ab3e2 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -502,20 +502,6 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- -// CHECK-LABEL: func @convert_f32_to_i1 -func @convert_f32_to_i1(%input: memref<2x2xf32>, %result: memref<2x2xi1>) { - "lmhlo.convert"(%input, %result) - : (memref<2x2xf32>, memref<2x2xi1>) -> () - return -} -// CHECK: linalg.generic -// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]: i1): -// CHECK-NEXT: %[[ZERO:.*]] = constant 0.000000e+00 : f32 -// CHECK-NEXT: %[[RESULT:.*]] = cmpf une, %[[OPERAND_IN]], %[[ZERO]] : f32 -// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 - -// ----- - // CHECK-LABEL: func @convert_f32_to_i32 func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) { "lmhlo.convert"(%input, %result)