[mhlo] Lower int->bool to a comparison with zero
This matches what TF (and C++) do in this case. PiperOrigin-RevId: 357566262
This commit is contained in:
parent
240a44de82
commit
b594254c79
|
@ -311,6 +311,26 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
|
||||||
// No conversion is needed for the same width floats
|
// No conversion is needed for the same width floats
|
||||||
return args.front();
|
return args.front();
|
||||||
}
|
}
|
||||||
|
if (targetType.isInteger(/*width=*/1)) {
|
||||||
|
// When casting to bool, we need to compare whether the value is equal to
|
||||||
|
// zero.
|
||||||
|
if (sourceType.isSignlessInteger()) {
|
||||||
|
Value zero_intval = b->create<::mlir::ConstantIntOp>(
|
||||||
|
loc, 0, sourceType.cast<IntegerType>().getWidth());
|
||||||
|
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
|
||||||
|
zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval);
|
||||||
|
}
|
||||||
|
return b->create<mlir::CmpIOp>(loc, CmpIPredicate::ne, args.front(),
|
||||||
|
zero_intval);
|
||||||
|
} else if (sourceType.isa<FloatType>()) {
|
||||||
|
Value zero = b->create<ConstantOp>(loc, b->getFloatAttr(sourceType, 0.0));
|
||||||
|
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
|
||||||
|
zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
|
||||||
|
}
|
||||||
|
return b->create<mlir::CmpFOp>(loc, CmpFPredicate::UNE, args.front(),
|
||||||
|
zero);
|
||||||
|
}
|
||||||
|
}
|
||||||
if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) {
|
if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) {
|
||||||
IntegerType src = sourceType.cast<IntegerType>();
|
IntegerType src = sourceType.cast<IntegerType>();
|
||||||
IntegerType res = targetType.cast<IntegerType>();
|
IntegerType res = targetType.cast<IntegerType>();
|
||||||
|
@ -327,13 +347,6 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
|
||||||
// No conversion is needed for the same width integers
|
// No conversion is needed for the same width integers
|
||||||
return args.front();
|
return args.front();
|
||||||
}
|
}
|
||||||
if (targetType.isInteger(/*width=*/1)) {
|
|
||||||
Value zero = b->create<ConstantOp>(loc, b->getFloatAttr(sourceType, 0.0));
|
|
||||||
if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) {
|
|
||||||
zero = b->create<::mlir::SplatOp>(loc, vec_type, zero);
|
|
||||||
}
|
|
||||||
return b->create<mlir::CmpFOp>(loc, CmpFPredicate::UNE, args.front(), zero);
|
|
||||||
}
|
|
||||||
if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) {
|
if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) {
|
||||||
return b->create<mlir::FPToSIOp>(loc, result_types, args, mlir::None);
|
return b->create<mlir::FPToSIOp>(loc, result_types, args, mlir::None);
|
||||||
}
|
}
|
||||||
|
|
|
@ -740,6 +740,20 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @convert_i32_to_i1
|
||||||
|
func @convert_i32_to_i1(%input: tensor<2x2xi32>) -> tensor<2x2xi1> {
|
||||||
|
%result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi1>
|
||||||
|
return %result : tensor<2x2xi1>
|
||||||
|
}
|
||||||
|
// CHECK: linalg.init_tensor
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %{{.*}}: i1):
|
||||||
|
// CHECK-NEXT: %[[ZERO:.*]] = constant 0 : i32
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = cmpi ne, %[[OPERAND_IN]], %[[ZERO]] : i32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @convert_f32_to_i1
|
// CHECK-LABEL: func @convert_f32_to_i1
|
||||||
func @convert_f32_to_i1(%input: tensor<2x2xf32>) -> tensor<2x2xi1> {
|
func @convert_f32_to_i1(%input: tensor<2x2xf32>) -> tensor<2x2xi1> {
|
||||||
%result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi1>
|
%result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||||
|
|
|
@ -514,6 +514,20 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @convert_i32_to_i1
|
||||||
|
func @convert_i32_to_i1(%input: memref<2x2xi32>, %result: memref<2x2xi1>) {
|
||||||
|
"lmhlo.convert"(%input, %result)
|
||||||
|
: (memref<2x2xi32>, memref<2x2xi1>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i32, %[[RESULT_OUT:.*]]: i1):
|
||||||
|
// CHECK-NEXT: %[[ZERO:.*]] = constant 0 : i32
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = cmpi ne, %[[OPERAND_IN]], %[[ZERO]] : i32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i1
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @convert_f32_to_i1
|
// CHECK-LABEL: func @convert_f32_to_i1
|
||||||
func @convert_f32_to_i1(%input: memref<2x2xf32>, %result: memref<2x2xi1>) {
|
func @convert_f32_to_i1(%input: memref<2x2xf32>, %result: memref<2x2xi1>) {
|
||||||
"lmhlo.convert"(%input, %result)
|
"lmhlo.convert"(%input, %result)
|
||||||
|
|
Loading…
Reference in New Issue