From 181d2cad310c5e6838042bb060c0d67a08b6b8f5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 13 Jan 2021 05:36:19 -0800 Subject: [PATCH] [MLIR][KernelGen] Add `tf.Log1p` kernel and tests PiperOrigin-RevId: 351566460 --- .../Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h | 1 + .../Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h | 12 ++++++++++++ lib/Dialect/mhlo/transforms/legalize_to_linalg.cc | 2 ++ 3 files changed, 15 insertions(+) diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index ef36f41..b2750e7 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -60,6 +60,7 @@ MAP_HLO_TO_LHLO(ImagOp); MAP_HLO_TO_LHLO(IotaOp); MAP_HLO_TO_LHLO(IsFiniteOp); MAP_HLO_TO_LHLO(LogOp); +MAP_HLO_TO_LHLO(Log1pOp); MAP_HLO_TO_LHLO(MaxOp); MAP_HLO_TO_LHLO(MinOp); MAP_HLO_TO_LHLO(MulOp); diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index eadc32c..d91770e 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -429,6 +429,18 @@ inline Value MapLhloOpToStdScalarOp(Location loc, loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { + auto ty = result_types.front().cast(); + Value x = args.front(); + Value one = b->create(loc, b->getFloatAttr(ty, 1.0)); + Value x_plus_one = b->create(loc, x, one); + return b->create<::mlir::LogOp>(loc, x_plus_one); +} + template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 9a31f8c..e677b04 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -1239,6 +1239,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -1359,6 +1360,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter,