[mhlo] Lower int->int cast to sign extension instead of zero extension
Signless does not mean unsigned here. Currently mhlo only has signed types. PiperOrigin-RevId: 357561712
This commit is contained in:
parent
8672735e9a
commit
240a44de82
|
@ -316,9 +316,13 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
|
||||||
IntegerType res = targetType.cast<IntegerType>();
|
IntegerType res = targetType.cast<IntegerType>();
|
||||||
if (src.getWidth() > res.getWidth()) {
|
if (src.getWidth() > res.getWidth()) {
|
||||||
return b->create<mlir::TruncateIOp>(loc, result_types, args, mlir::None);
|
return b->create<mlir::TruncateIOp>(loc, result_types, args, mlir::None);
|
||||||
} else if (src.getWidth() < res.getWidth()) {
|
} else if (src.getWidth() == 1) {
|
||||||
|
// Special case boolean values, so they get casted to `1` instead of `-1`.
|
||||||
return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
|
return b->create<mlir::ZeroExtendIOp>(loc, result_types, args,
|
||||||
mlir::None);
|
mlir::None);
|
||||||
|
} else if (src.getWidth() < res.getWidth()) {
|
||||||
|
return b->create<mlir::SignExtendIOp>(loc, result_types, args,
|
||||||
|
mlir::None);
|
||||||
}
|
}
|
||||||
// No conversion is needed for the same width integers
|
// No conversion is needed for the same width integers
|
||||||
return args.front();
|
return args.front();
|
||||||
|
|
|
@ -662,6 +662,19 @@ func @convert_i1_to_f32(%input: tensor<2x2xi1>) -> tensor<2x2xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @convert_i1_to_i32
|
||||||
|
func @convert_i1_to_i32(%input: tensor<2x2xi1>) -> tensor<2x2xi32> {
|
||||||
|
%result = "mhlo.convert"(%input) : (tensor<2x2xi1>) -> tensor<2x2xi32>
|
||||||
|
return %result : tensor<2x2xi32>
|
||||||
|
}
|
||||||
|
// CHECK: linalg.init_tensor
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i1, %{{.*}}: i32):
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i1 to i32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @convert_i32_to_f32
|
// CHECK-LABEL: func @convert_i32_to_f32
|
||||||
func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
|
func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
|
||||||
%result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
|
%result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
|
||||||
|
@ -683,7 +696,7 @@ func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {
|
||||||
// CHECK: linalg.init_tensor
|
// CHECK: linalg.init_tensor
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %{{.*}}: i32):
|
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %{{.*}}: i32):
|
||||||
// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i16 to i32
|
// CHECK-NEXT: %[[RESULT:.*]] = sexti %[[OPERAND_IN]] : i16 to i32
|
||||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
|
@ -419,6 +419,18 @@ func @convert_i1_to_f32(%input: memref<2x2xi1>, %result: memref<2x2xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @convert_i1_to_i32
|
||||||
|
func @convert_i1_to_i32(%input: memref<2x2xi1>, %result: memref<2x2xi32>) {
|
||||||
|
"lmhlo.convert"(%input, %result) : (memref<2x2xi1>, memref<2x2xi32>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i1, %[[RESULT_OUT:.*]]: i32):
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i1 to i32
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @convert_i32_to_f32
|
// CHECK-LABEL: func @convert_i32_to_f32
|
||||||
func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) {
|
func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) {
|
||||||
"lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> ()
|
"lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> ()
|
||||||
|
@ -439,7 +451,7 @@ func @convert_i16_to_i32(%input: memref<2x2xi16>,
|
||||||
}
|
}
|
||||||
// CHECK: linalg.generic
|
// CHECK: linalg.generic
|
||||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]: i32):
|
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]: i32):
|
||||||
// CHECK-NEXT: %[[RESULT:.*]] = zexti %[[OPERAND_IN]] : i16 to i32
|
// CHECK-NEXT: %[[RESULT:.*]] = sexti %[[OPERAND_IN]] : i16 to i32
|
||||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i32
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
Loading…
Reference in New Issue