diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index b65f68f..20bb377 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -57,6 +57,7 @@ MAP_HLO_TO_LHLO(FloorOp); MAP_HLO_TO_LHLO(GatherOp); MAP_HLO_TO_LHLO(ImagOp); MAP_HLO_TO_LHLO(IotaOp); +MAP_HLO_TO_LHLO(IsFiniteOp); MAP_HLO_TO_LHLO(LogOp); MAP_HLO_TO_LHLO(MaxOp); MAP_HLO_TO_LHLO(MinOp); diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index f6076ef..d2fb048 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -345,6 +345,22 @@ inline Value MapLhloOpToStdScalarOp(Location loc, loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + if (args[0].getType().isa()) { + auto pos_inf = APFloat::getInf( + args[0].getType().cast().getFloatSemantics()); + auto const_pos_inf = + b->create(loc, b->getFloatAttr(args[0].getType(), pos_inf)); + Value abs_x = b->create<::mlir::AbsFOp>(loc, args[0]); + return b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, abs_x, + const_pos_inf); + } + return nullptr; +} + /// Implements the conversion of HLO op to scalar op (to use within region of a /// linalg.generic op) for compare-select style operations like min/max. template diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 4fd7b4c..2d36b1c 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -848,6 +848,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, ReshapeOpConverter, ReverseConverter, ScalarPointwiseToStandardConverter, @@ -955,6 +956,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, ReshapeOpConverter, ReverseConverter, TransposeConverter>(context); diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 263ea1b..91490b4 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -252,6 +252,20 @@ func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> { // ----- +// CHECK-LABEL: func @is_finte +func @is_finte(%input: tensor<2x2xf32>) -> tensor<2x2xi1> { + %0 = "mhlo.is_finite"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi1> + return %0 : tensor<2x2xi1> +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32 +// CHECK-NEXT: %[[POS_INF:.+]] = constant 0x7F800000 : f32 +// CHECK-NEXT: %[[ABS_X:.+]] = absf %[[OPERAND_IN]] : f32 +// CHECK-NEXT: %[[RESULT:.+]] = cmpf "one", %[[ABS_X]], %[[POS_INF]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + // CHECK-LABEL: func @select func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { diff --git a/tests/lhlo-legalize-to-linalg.mlir b/tests/lhlo-legalize-to-linalg.mlir index 7ef1f0b..debb035 100644 --- a/tests/lhlo-legalize-to-linalg.mlir +++ b/tests/lhlo-legalize-to-linalg.mlir @@ -125,6 +125,20 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) { // ----- +// CHECK-LABEL: func @is_finte +func @is_finte(%input: memref<2x2xf32>, %result: memref<2x2xi1>) { + "lmhlo.is_finite"(%input, %result) : (memref<2x2xf32>, memref<2x2xi1>) -> () + return +} +// CHECK: linalg.generic +// CHECK-NEXT: ^bb0(%[[OPERAND_IN:.*]]: f32, %[[RESULT_OUT:.*]]): +// CHECK-NEXT: %[[POS_INF:.+]] = constant 0x7F800000 : f32 +// CHECK-NEXT: %[[ABS_X:.+]] = absf %[[OPERAND_IN]] : f32 +// CHECK-NEXT: %[[RESULT:.+]] = cmpf "one", %[[ABS_X]], %[[POS_INF]] : f32 +// CHECK-NEXT: linalg.yield %[[RESULT]] : i1 + +// ----- + // CHECK-LABEL: func @float_cmp func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {