From c36afd275e9132660a2281ccb19e6ca15142aa1d Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Wed, 9 Dec 2020 23:24:23 -0800 Subject: [PATCH] [HLO] Add a pattern for HLO ConstOp to HLO -> Linalg conversion. PiperOrigin-RevId: 346718273 --- .../mhlo/transforms/legalize_to_linalg.cc | 39 +++++++++++++------ tests/hlo-legalize-to-linalg.mlir | 11 ++++++ 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 69e729c..31b7a60 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -733,23 +733,37 @@ class IotaConverter : public OpConversionPattern { } }; -class ConstConverter : public OpConversionPattern { +template +class ConstConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - lmhlo::ConstOp const_op, ArrayRef args, + OpTy const_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { - auto loc = const_op.getLoc(); - auto value_attr = const_op.value().cast(); + Location loc = const_op.getLoc(); + auto value_attr = const_op.value().template cast(); if (value_attr.getType().getRank() != 0) return failure(); - auto std_const_op = - rewriter.create(loc, value_attr.getValue({})); - rewriter.create(loc, std_const_op, - const_op.getOperand(), ValueRange()); - rewriter.eraseOp(const_op); + ReplaceConstOp(loc, const_op, value_attr, rewriter); return success(); } + + private: + void ReplaceConstOp(Location loc, mhlo::ConstOp op, + DenseElementsAttr value_attr, + ConversionPatternRewriter& rewriter) const { + Value std_tensor_const = rewriter.create(loc, value_attr); + rewriter.replaceOp(op, {std_tensor_const}); + } + void ReplaceConstOp(Location loc, lmhlo::ConstOp op, + DenseElementsAttr value_attr, + ConversionPatternRewriter& rewriter) const { + Value std_scalar_const = + rewriter.create(loc, value_attr.getValue({})); + rewriter.create(loc, std_scalar_const, op.getOperand(), + llvm::None); + rewriter.eraseOp(op); + } }; class ReduceConverter : public OpConversionPattern { @@ -908,7 +922,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off patterns->insert, - ConstConverter, + ConstConverter, ConvToLinalgConverter, IotaConverter, LhloBroadcastInDimConverter, @@ -1029,7 +1043,8 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { patterns ->insert, - HloBroadcastInDimConverter, IotaConverter, + ConstConverter, HloBroadcastInDimConverter, + IotaConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, PointwiseToLinalgConverter, diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 53fd205..62f416f 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -696,3 +696,14 @@ func @shift_right_logical(%lhs: tensor<2x2xi32>, // CHECK-NEXT: ^bb0(%[[LHS:.*]]: i32, %[[RHS:.*]]: i32): // CHECK-NEXT: %[[RESULT:.*]] = shift_right_unsigned %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: linalg.yield %[[RESULT]] : i32 + +// ----- + +// CHECK-LABEL: func @constant +func @constant() { + %result = "mhlo.constant"() { + value = dense<10> : tensor + } : () -> (tensor) + return +} +// CHECK: %[[CONSTANT:.*]] = constant dense<10> : tensor