From ef8ccdaebc474c62bd7515acfe5e3d8992a585a2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 22 Jan 2021 03:02:13 -0800 Subject: [PATCH] [MLIR] Add mhlo.logistic lowering to linalg PiperOrigin-RevId: 353205440 --- include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td | 2 ++ .../Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h | 1 + .../mhlo/transforms/map_lmhlo_to_scalar_op.h | 13 +++++++++++++ .../mhlo/transforms/hlo_legalize_to_lhlo.cc | 1 + .../mhlo/transforms/legalize_to_linalg.cc | 5 ++++- tests/hlo-legalize-to-linalg.mlir | 16 ++++++++++++++++ 6 files changed, 37 insertions(+), 1 deletion(-) diff --git a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 4a46e4b..30a8222 100644 --- a/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -114,6 +114,8 @@ def LHLO_IsFiniteOp: LHLO_Op<"is_finite", [SameOperandsShape]>, BASE_HLO_IsFinit def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer>, BASE_HLO_LogOp; +def LHLO_LogisticOp : LHLO_UnaryElementwiseOp<"logistic", LHLO_FpOrComplexBuffer>, BASE_HLO_LogisticOp; + def LHLO_Log1pOp: LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Log1pOp; def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp; diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index 73383f0..19f1c72 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -61,6 +61,7 @@ MAP_HLO_TO_LHLO(ImagOp); MAP_HLO_TO_LHLO(IotaOp); MAP_HLO_TO_LHLO(IsFiniteOp); MAP_HLO_TO_LHLO(LogOp); +MAP_HLO_TO_LHLO(LogisticOp); MAP_HLO_TO_LHLO(Log1pOp); MAP_HLO_TO_LHLO(MaxOp); MAP_HLO_TO_LHLO(MinOp); diff --git a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h index d6174d9..9354830 100644 --- a/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h +++ b/include/mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h @@ -437,6 +437,19 @@ inline Value MapLhloOpToStdScalarOp(Location loc, loc, result_types, args, b); } +template <> +inline Value MapLhloOpToStdScalarOp( + Location loc, ArrayRef result_types, ArrayRef args, + OpBuilder* b) { + auto ty = result_types.front().cast(); + Value one = b->create(loc, b->getFloatAttr(ty, 1.0)); + Value x = args.front(); + Value neg_x = b->create(loc, x); + Value exp_neg_x = b->create<::mlir::ExpOp>(loc, neg_x); + Value one_add_exp_neg_x = b->create(loc, one, exp_neg_x); + return b->create(loc, one, one_add_exp_neg_x); +} + template <> inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 185c622..2b606cf 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -233,6 +233,7 @@ class HloToLhloDynamicBroadcastInDimOpConverter LogicalResult matchAndRewrite( mhlo::DynamicBroadcastInDimOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { + if (!op.getType().isa()) return failure(); Value result = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter); if (insert_copy_) { diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 7f630a5..deb7654 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -474,7 +474,8 @@ class HloDynamicBroadcastInDimConverter Value index = rewriter.create(loc, i); dyn_dims.push_back(rewriter.create(loc, shape, index)); } - auto result_type = op.getType().cast(); + auto result_type = op.getType().dyn_cast(); + if (!result_type) return failure(); int64_t nloops = result_type.getRank(); Value init = rewriter.create( @@ -1240,6 +1241,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, @@ -1364,6 +1366,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, + PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index adf667c..f4e9287 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -176,6 +176,22 @@ func @float_log1p(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- +// CHECK-LABEL: func @float_logistic +func @float_logistic(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: linalg.generic + // CHECK: ^bb0(%[[ARG:.*]]: f32, %{{.*}}: f32): + // CHECK: %[[C1:.*]] = constant 1.{{.*}}e+00 + // CHECK: %[[NEG_ARG:.*]] = negf %[[ARG]] + // CHECK: %[[EXP_NEG_ARG:.*]] = exp %[[NEG_ARG]] + // CHECK: %[[ONE_ADD_EXP_NEG_ARG:.*]] = addf %[[C1]], %[[EXP_NEG_ARG]] + // CHECK: %[[RESULT:.*]] = divf %[[C1]], %[[ONE_ADD_EXP_NEG_ARG]] + // CHECK: linalg.yield %[[RESULT]] + %0 = "mhlo.logistic"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: func @float_ceil func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: linalg.generic