Make lmhlo.sign to linalg lowering work for more floating point and integer types
PiperOrigin-RevId: 334400341
This commit is contained in:
parent
04bf09382e
commit
336ee14538
|
@ -469,11 +469,27 @@ inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
|
||||||
ArrayRef<Value> args,
|
ArrayRef<Value> args,
|
||||||
OpBuilder* b) {
|
OpBuilder* b) {
|
||||||
Type element_type = args.front().getType();
|
Type element_type = args.front().getType();
|
||||||
if (element_type.isa<FloatType>()) {
|
if (auto float_type = element_type.dyn_cast<FloatType>()) {
|
||||||
FloatType float_type = element_type.cast<FloatType>();
|
bool ignored;
|
||||||
APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0);
|
APFloat one_apfloat(1.0f);
|
||||||
Value one = b->create<mlir::ConstantFloatOp>(loc, const_value, float_type);
|
one_apfloat.convert(float_type.getFloatSemantics(),
|
||||||
|
APFloat::rmNearestTiesToEven, &ignored);
|
||||||
|
Value one = b->create<mlir::ConstantFloatOp>(loc, one_apfloat, float_type);
|
||||||
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
|
return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]);
|
||||||
|
} else if (auto integer_type = element_type.dyn_cast<IntegerType>()) {
|
||||||
|
// sign(x) = x == 0 ? 0 : ((x s>> 31) | 1)
|
||||||
|
Value zero =
|
||||||
|
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
|
||||||
|
Value cmp =
|
||||||
|
b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero);
|
||||||
|
Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>(
|
||||||
|
loc, integer_type.getWidth() - 1, integer_type.getWidth());
|
||||||
|
Value ashr =
|
||||||
|
b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one);
|
||||||
|
Value one =
|
||||||
|
b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth());
|
||||||
|
Value or_op = b->create<::mlir::OrOp>(loc, ashr, one);
|
||||||
|
return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op);
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
|
@ -586,6 +586,37 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @sign_bf16
|
||||||
|
func @sign_bf16(%input: memref<2x2xbf16>, %result: memref<2x2xbf16>) {
|
||||||
|
"lmhlo.sign"(%input, %result) : (memref<2x2xbf16>, memref<2x2xbf16>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: bf16, %[[RESULT_OUT:.*]]):
|
||||||
|
// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : bf16
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : bf16
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : bf16
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @sign_i16
|
||||||
|
func @sign_i16(%input: memref<2x2xi16>, %result: memref<2x2xi16>) {
|
||||||
|
"lmhlo.sign"(%input, %result) : (memref<2x2xi16>, memref<2x2xi16>) -> ()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// CHECK: linalg.generic
|
||||||
|
// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]):
|
||||||
|
// CHECK-NEXT: %[[C0:.*]] = constant 0 : i16
|
||||||
|
// CHECK-NEXT: %[[CMP:.*]] = cmpi "eq", %[[OPERAND_IN]], %[[C0]] : i16
|
||||||
|
// CHECK-NEXT: %[[C15:.*]] = constant 15 : i16
|
||||||
|
// CHECK-NEXT: %[[ASHR:.*]] = shift_right_signed %[[OPERAND_IN]], %[[C15]] : i16
|
||||||
|
// CHECK-NEXT: %[[C1:.*]] = constant 1 : i16
|
||||||
|
// CHECK-NEXT: %[[OR:.*]] = or %[[ASHR]], %[[C1]] : i16
|
||||||
|
// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[C0]], %[[OR]] : i16
|
||||||
|
// CHECK-NEXT: linalg.yield %[[RESULT]] : i16
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @sqrt
|
// CHECK-LABEL: func @sqrt
|
||||||
func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||||
"lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
"lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||||
|
|
Loading…
Reference in New Issue