[mhlo] Lower float->bool to a comparison with zero
This matches what TF (and C++) do in this case. PiperOrigin-RevId: 357534118
This commit is contained in:
parent
824bc9c425
commit
3e80d91e73
|
@ -323,6 +323,10 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
|
|||
// No conversion is needed for the same width integers
|
||||
return args.front();
|
||||
}
|
||||
if (targetType.isInteger(/*width=*/1)) {
|
||||
auto zero = b->create<ConstantOp>(loc, b->getFloatAttr(sourceType, 0.0));
|
||||
return b->create<mlir::CmpFOp>(loc, CmpFPredicate::UNE, args.front(), zero);
|
||||
}
|
||||
if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) {
|
||||
return b->create<mlir::FPToSIOp>(loc, result_types, args, mlir::None);
|
||||
}
|
||||
|
|
|
@ -727,6 +727,20 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @convert_f32_to_i1
|
||||
func @convert_f32_to_i1(%input: tensor<2x2xf32>) -> tensor<2x2xi1> {
|
||||
%result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||
return %result : tensor<2x2xi1>
|
||||
}
|
||||
// CHECK: linalg.init_tensor
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %{{.*}}: i1):
|
||||
// CHECK-NEXT: %[[ZERO:.*]] = constant 0.000000e+00 : f32
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = cmpf une, %[[OPERAND_IN]], %[[ZERO]] : f32
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @convert_f32_to_i32
|
||||
func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
|
||||
%result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
|
||||
|
|
|
@ -502,6 +502,20 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @convert_f32_to_i1
|
||||
func @convert_f32_to_i1(%input: memref<2x2xf32>, %result: memref<2x2xi1>) {
|
||||
"lmhlo.convert"(%input, %result)
|
||||
: (memref<2x2xf32>, memref<2x2xi1>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]: i1):
|
||||
// CHECK-NEXT: %[[ZERO:.*]] = constant 0.000000e+00 : f32
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = cmpf une, %[[OPERAND_IN]], %[[ZERO]] : f32
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @convert_f32_to_i32
|
||||
func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) {
|
||||
"lmhlo.convert"(%input, %result)
|
||||
|
|
Loading…
Reference in New Issue