diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index 19f1c72..252b15d 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -55,6 +55,7 @@ MAP_HLO_TO_LHLO(CustomCallOp); MAP_HLO_TO_LHLO(DivOp); MAP_HLO_TO_LHLO(DotOp); MAP_HLO_TO_LHLO(ExpOp); +MAP_HLO_TO_LHLO(Expm1Op); MAP_HLO_TO_LHLO(FloorOp); MAP_HLO_TO_LHLO(GatherOp); MAP_HLO_TO_LHLO(ImagOp); diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index d726c7f..5a957e7 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -251,6 +251,15 @@ inline Value MapLhloOpToStdScalarOp(Location loc, loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); +} + template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 7ec6629..b562e3b 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -658,6 +658,7 @@ void populateHLOToLHLOConversionPattern(MLIRContext* context, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 6cbb68c..8dff2f5 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -1398,6 +1398,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -1524,6 +1525,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index a25ee89..5c49724 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -87,6 +87,15 @@ func @exp(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @expm1 +func @expm1(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { + %result = "mhlo.exponential_minus_one"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: "lmhlo.exponential_minus_one"(%{{.*}}, %{{.*}}) + return %result : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: func @log func @log(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { %result = "mhlo.log"(%operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 545a781..99415e0 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -156,6 +156,16 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @float_expm1 +func @float_expm1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: expm1 + %0 = "mhlo.exponential_minus_one"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: func @float_log func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic