diff --git a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index 53cad4c..71a5d25 100644 --- a/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -1077,36 +1077,42 @@ struct ConcatenateConverter : public OpConversionPattern { } }; -template -class ConstConverter : public OpConversionPattern { +class ConstConverterBuffer : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - OpTy const_op, ArrayRef /*args*/, + lmhlo::ConstOp const_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { Location loc = const_op.getLoc(); - auto value_attr = const_op.value().template cast(); + auto value_attr = const_op.value().cast(); if (value_attr.getType().getRank() != 0) return failure(); - ReplaceConstOp(loc, const_op, value_attr, rewriter); - return success(); - } - - private: - void ReplaceConstOp(Location loc, mhlo::ConstOp op, - DenseElementsAttr value_attr, - ConversionPatternRewriter& rewriter) const { - Value std_tensor_const = rewriter.create(loc, value_attr); - rewriter.replaceOp(op, {std_tensor_const}); - } - void ReplaceConstOp(Location loc, lmhlo::ConstOp op, - DenseElementsAttr value_attr, - ConversionPatternRewriter& rewriter) const { Value std_scalar_const = rewriter.create(loc, value_attr.getValue({})); - rewriter.create(loc, std_scalar_const, op.getOperand(), - llvm::None); - rewriter.eraseOp(op); + rewriter.create(loc, std_scalar_const, + const_op.getOperand(), llvm::None); + rewriter.eraseOp(const_op); + return success(); + } +}; + +class ConstConverterTensor : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mhlo::ConstOp const_op, ArrayRef /*args*/, + ConversionPatternRewriter& rewriter) const final { + auto value_attr = const_op.value().cast(); + auto type = + typeConverter->convertType(const_op.getType()).cast(); + if (type != const_op.getType()) { + // Signedness conversion. + value_attr = value_attr.mapValues(type.getElementType(), + [](const APInt& i) { return i; }); + } + rewriter.replaceOpWithNewOp(const_op, type, value_attr); + return success(); } }; @@ -2291,7 +2297,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off patterns->insert, - ConstConverter, + ConstConverterBuffer, ConvToLinalgConverter, IotaConverter, LhloBroadcastInDimConverter, @@ -2490,7 +2496,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, // clang-format off patterns->insert< BroadcastConverter, ConcatenateConverter, - ConstConverter, HloDynamicBroadcastInDimConverter, + ConstConverterTensor, HloDynamicBroadcastInDimConverter, HloBroadcastInDimConverter, IotaConverter, IotaConverter, PointwiseToLinalgConverter, diff --git a/tests/hlo-legalize-to-linalg.mlir b/tests/hlo-legalize-to-linalg.mlir index 37f778a..fd528cc 100644 --- a/tests/hlo-legalize-to-linalg.mlir +++ b/tests/hlo-legalize-to-linalg.mlir @@ -2698,3 +2698,26 @@ func @scatter_update_slice(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>, // CHECK: linalg.yield %[[SELECT]] : i32 // CHECK: } -> tensor<6x3xi32> // CHECK: return %[[RES]] : tensor<6x3xi32> + +// ----- + +func @const() -> tensor<3xi32> { + // CHECK: = constant dense<[1, 2, 3]> : tensor<3xi32> + %cst = mhlo.constant dense<[1, 2, 3]> : tensor<3xi32> + return %cst : tensor<3xi32> +} +// ----- + +func @const_unsigned() -> tensor<3xui32> { + // CHECK: = constant dense<[1, 2, 3]> : tensor<3xi32> + %cst = mhlo.constant dense<[1, 2, 3]> : tensor<3xui32> + return %cst : tensor<3xui32> +} + +// ----- + +func @const_splat() -> tensor<3xi16> { + // CHECK: = constant dense<1> : tensor<3xi16> + %cst = mhlo.constant dense<1> : tensor<3xi16> + return %cst : tensor<3xi16> +}