diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 3102082..f6076ef 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -469,11 +469,27 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef args, OpBuilder* b) { Type element_type = args.front().getType(); - if (element_type.isa()) { - FloatType float_type = element_type.cast(); - APFloat const_value = float_type.isF32() ? APFloat(1.0f) : APFloat(1.0); - Value one = b->create(loc, const_value, float_type); + if (auto float_type = element_type.dyn_cast()) { + bool ignored; + APFloat one_apfloat(1.0f); + one_apfloat.convert(float_type.getFloatSemantics(), + APFloat::rmNearestTiesToEven, &ignored); + Value one = b->create(loc, one_apfloat, float_type); return b->create<::mlir::CopySignOp>(loc, result_types, one, args[0]); + } else if (auto integer_type = element_type.dyn_cast()) { + // sign(x) = x == 0 ? 0 : ((x s>> 31) | 1) + Value zero = + b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); + Value cmp = + b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero); + Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>( + loc, integer_type.getWidth() - 1, integer_type.getWidth()); + Value ashr = + b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one); + Value one = + b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth()); + Value or_op = b->create<::mlir::OrOp>(loc, ashr, one); + return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op); } return nullptr; } diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 6d79d2f..7ef1f0b 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -586,6 +586,37 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // ----- +// CHECK-LABEL: func @sign_bf16 +func @sign_bf16(%input: memref<2x2xbf16>, %result: memref<2x2xbf16>) { + "lmhlo.sign"(%input, %result) : (memref<2x2xbf16>, memref<2x2xbf16>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: bf16, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[CST:.*]] = constant 1.000000e+00 : bf16 +// CHECK-NEXT: %[[RESULT:.*]] = copysign %[[CST]], %[[OPERAND_IN]] : bf16 +// CHECK-NEXT: linalg.yield %[[RESULT]] : bf16 + +// ----- + +// CHECK-LABEL: func @sign_i16 +func @sign_i16(%input: memref<2x2xi16>, %result: memref<2x2xi16>) { + "lmhlo.sign"(%input, %result) : (memref<2x2xi16>, memref<2x2xi16>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: i16, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[C0:.*]] = constant 0 : i16 +// CHECK-NEXT: %[[CMP:.*]] = cmpi "eq", %[[OPERAND_IN]], %[[C0]] : i16 +// CHECK-NEXT: %[[C15:.*]] = constant 15 : i16 +// CHECK-NEXT: %[[ASHR:.*]] = shift_right_signed %[[OPERAND_IN]], %[[C15]] : i16 +// CHECK-NEXT: %[[C1:.*]] = constant 1 : i16 +// CHECK-NEXT: %[[OR:.*]] = or %[[ASHR]], %[[C1]] : i16 +// CHECK-NEXT: %[[RESULT:.*]] = select %[[CMP]], %[[C0]], %[[OR]] : i16 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i16 + +// ----- + // CHECK-LABEL: func @sqrt func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { "lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()