diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index 77600e0..d6174d9 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -442,11 +442,8 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { - auto ty = result_types.front().cast(); - Value x = args.front(); - Value one = b->create(loc, b->getFloatAttr(ty, 1.0)); - Value x_plus_one = b->create(loc, x, one); - return b->create<::mlir::LogOp>(loc, x_plus_one); + return MapLhloOpToStdScalarOpImpl{}( + loc, result_types, args, b); } template <> diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index ade42d8..adf667c 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -166,6 +166,16 @@ func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @float_log1p +func @float_log1p(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: log1p + %0 = "mhlo.log_plus_one"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: func @float_ceil func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic