Make lmhlo.sign to linalg lowering work for more floating point and integer types
PiperOrigin-RevId: 334400341
This commit is contained in:
parent
04bf09382e
commit
336ee14538
|
@ -469,11 +469,27 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
|
|||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
if (element_type.isa<FloatType>()) {
|
||||
FloatType float_type = element_type.cast<FloatType>();
|
||||
APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0);
|
||||
Value one = b->create<mlir::ConstantFloatOp>(loc, const_value, float_type);
|
||||
if (auto float_type = element_type.dyn_cast<FloatType>()) {
|
||||
bool ignored;
|
||||
APFloat one_apfloat(1.0f);
|
||||
one_apfloat.convert(float_type.getFloatSemantics(),
|
||||
APFloat::rmNearestTiesToEven, &ignored);
|
||||
Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type);
|
||||
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
|
||||
} else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
|
||||
// sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)
|
||||
Value zero =
|
||||
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
|
||||
Value cmp =
|
||||
b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
|
||||
Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>(
|
||||
loc, integer_type.getWidth() - 1, integer_type.getWidth());
|
||||
Value ashr =
|
||||
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
|
||||
Value one =
|
||||
b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth());
|
||||
Value or_op = b->create<::mlir::OrOp>(loc, ashr, one);
|
||||
return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -586,6 +586,37 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @sign_bf16
|
||||
func @sign_bf16(%input: memref<2x2xbf16>, %result: memref<2x2xbf16>) {
|
||||
"lmhlo.sign"(%input, %result) : (memref<2x2xbf16>, memref<2x2xbf16>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: bf16, %[[RESULT_OUT:.*]]):
|
||||
// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : bf16
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : bf16
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : bf16
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @sign_i16
|
||||
func @sign_i16(%input: memref<2x2xi16>, %result: memref<2x2xi16>) {
|
||||
"lmhlo.sign"(%input, %result) : (memref<2x2xi16>, memref<2x2xi16>) -> ()
|
||||
return
|
||||
}
|
||||
// CHECK: linalg.generic
|
||||
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]):
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 0 : i16
|
||||
// CHECK-NEXT: %[[CMP:.*]] = cmpi "eq", %[[OPERAND_IN]], %[[C0]] : i16
|
||||
// CHECK-NEXT: %[[C15:.*]] = constant 15 : i16
|
||||
// CHECK-NEXT: %[[ASHR:.*]] = shift_right_signed %[[OPERAND_IN]], %[[C15]] : i16
|
||||
// CHECK-NEXT: %[[C1:.*]] = constant 1 : i16
|
||||
// CHECK-NEXT: %[[OR:.*]] = or %[[ASHR]], %[[C1]] : i16
|
||||
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[C0]], %[[OR]] : i16
|
||||
// CHECK-NEXT: linalg.yield %[[RESULT]] : i16
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @sqrt
|
||||
func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
|
|
Loading…
Reference in New Issue