[MHLO:linalg] Be more aggressive about turning mhlo.const into std.constant
On tensors the only difference between these ops is that mhlo.const supports unsigned types. PiperOrigin-RevId: 377970948
This commit is contained in:
		
							parent
							
								
									25b93c8d66
								
							
						
					
					
						commit
						d1c60df2fe
					
				|  | @ -1077,36 +1077,42 @@ struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> { | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| template <typename OpTy> | class ConstConverterBuffer : public OpConversionPattern<lmhlo::ConstOp> { | ||||||
| class ConstConverter : public OpConversionPattern<OpTy> { |  | ||||||
|  public: |  public: | ||||||
|   using OpConversionPattern<OpTy>::OpConversionPattern; |   using OpConversionPattern::OpConversionPattern; | ||||||
| 
 | 
 | ||||||
|   LogicalResult matchAndRewrite( |   LogicalResult matchAndRewrite( | ||||||
|       OpTy const_op, ArrayRef<Value> /*args*/, |       lmhlo::ConstOp const_op, ArrayRef<Value> /*args*/, | ||||||
|       ConversionPatternRewriter& rewriter) const final { |       ConversionPatternRewriter& rewriter) const final { | ||||||
|     Location loc = const_op.getLoc(); |     Location loc = const_op.getLoc(); | ||||||
|     auto value_attr = const_op.value().template cast<DenseElementsAttr>(); |     auto value_attr = const_op.value().cast<DenseElementsAttr>(); | ||||||
|     if (value_attr.getType().getRank() != 0) return failure(); |     if (value_attr.getType().getRank() != 0) return failure(); | ||||||
|     ReplaceConstOp(loc, const_op, value_attr, rewriter); |  | ||||||
|     return success(); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|  private: |  | ||||||
|   void ReplaceConstOp(Location loc, mhlo::ConstOp op, |  | ||||||
|                       DenseElementsAttr value_attr, |  | ||||||
|                       ConversionPatternRewriter& rewriter) const { |  | ||||||
|     Value std_tensor_const = rewriter.create<mlir::ConstantOp>(loc, value_attr); |  | ||||||
|     rewriter.replaceOp(op, {std_tensor_const}); |  | ||||||
|   } |  | ||||||
|   void ReplaceConstOp(Location loc, lmhlo::ConstOp op, |  | ||||||
|                       DenseElementsAttr value_attr, |  | ||||||
|                       ConversionPatternRewriter& rewriter) const { |  | ||||||
|     Value std_scalar_const = |     Value std_scalar_const = | ||||||
|         rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({})); |         rewriter.create<mlir::ConstantOp>(loc, value_attr.getValue({})); | ||||||
|     rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const, op.getOperand(), |     rewriter.create<mlir::AffineStoreOp>(loc, std_scalar_const, | ||||||
|                                          llvm::None); |                                          const_op.getOperand(), llvm::None); | ||||||
|     rewriter.eraseOp(op); |     rewriter.eraseOp(const_op); | ||||||
|  |     return success(); | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | class ConstConverterTensor : public OpConversionPattern<mhlo::ConstOp> { | ||||||
|  |  public: | ||||||
|  |   using OpConversionPattern::OpConversionPattern; | ||||||
|  | 
 | ||||||
|  |   LogicalResult matchAndRewrite( | ||||||
|  |       mhlo::ConstOp const_op, ArrayRef<Value> /*args*/, | ||||||
|  |       ConversionPatternRewriter& rewriter) const final { | ||||||
|  |     auto value_attr = const_op.value().cast<DenseElementsAttr>(); | ||||||
|  |     auto type = | ||||||
|  |         typeConverter->convertType(const_op.getType()).cast<ShapedType>(); | ||||||
|  |     if (type != const_op.getType()) { | ||||||
|  |       // Signedness conversion.
 | ||||||
|  |       value_attr = value_attr.mapValues(type.getElementType(), | ||||||
|  |                                         [](const APInt& i) { return i; }); | ||||||
|  |     } | ||||||
|  |     rewriter.replaceOpWithNewOp<ConstantOp>(const_op, type, value_attr); | ||||||
|  |     return success(); | ||||||
|   } |   } | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
|  | @ -2291,7 +2297,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context, | ||||||
|                                            OwningRewritePatternList* patterns) { |                                            OwningRewritePatternList* patterns) { | ||||||
|   // clang-format off
 |   // clang-format off
 | ||||||
|   patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>, |   patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>, | ||||||
|                    ConstConverter<lmhlo::ConstOp>, |                    ConstConverterBuffer, | ||||||
|                    ConvToLinalgConverter, |                    ConvToLinalgConverter, | ||||||
|                    IotaConverter<lmhlo::IotaOp>, |                    IotaConverter<lmhlo::IotaOp>, | ||||||
|                    LhloBroadcastInDimConverter, |                    LhloBroadcastInDimConverter, | ||||||
|  | @ -2490,7 +2496,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context, | ||||||
|   // clang-format off
 |   // clang-format off
 | ||||||
|   patterns->insert< |   patterns->insert< | ||||||
|       BroadcastConverter<mhlo::BroadcastOp, false>, ConcatenateConverter, |       BroadcastConverter<mhlo::BroadcastOp, false>, ConcatenateConverter, | ||||||
|       ConstConverter<mhlo::ConstOp>, HloDynamicBroadcastInDimConverter, |       ConstConverterTensor, HloDynamicBroadcastInDimConverter, | ||||||
|       HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>, |       HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp, false>, | ||||||
|       IotaConverter<mhlo::DynamicIotaOp, false>, |       IotaConverter<mhlo::DynamicIotaOp, false>, | ||||||
|       PointwiseToLinalgConverter<mhlo::AbsOp, false>, |       PointwiseToLinalgConverter<mhlo::AbsOp, false>, | ||||||
|  |  | ||||||
|  | @ -2698,3 +2698,26 @@ func @scatter_update_slice(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>, | ||||||
| // CHECK:           linalg.yield %[[SELECT]] : i32 | // CHECK:           linalg.yield %[[SELECT]] : i32 | ||||||
| // CHECK:         } -> tensor<6x3xi32> | // CHECK:         } -> tensor<6x3xi32> | ||||||
| // CHECK:         return %[[RES]] : tensor<6x3xi32> | // CHECK:         return %[[RES]] : tensor<6x3xi32> | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | func @const() -> tensor<3xi32> { | ||||||
|  |   // CHECK: = constant dense<[1, 2, 3]> : tensor<3xi32> | ||||||
|  |   %cst = mhlo.constant dense<[1, 2, 3]> : tensor<3xi32> | ||||||
|  |   return %cst : tensor<3xi32> | ||||||
|  | } | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | func @const_unsigned() -> tensor<3xui32> { | ||||||
|  |   // CHECK: = constant dense<[1, 2, 3]> : tensor<3xi32> | ||||||
|  |   %cst = mhlo.constant dense<[1, 2, 3]> : tensor<3xui32> | ||||||
|  |   return %cst : tensor<3xui32> | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | // ----- | ||||||
|  | 
 | ||||||
|  | func @const_splat() -> tensor<3xi16> { | ||||||
|  |   // CHECK: = constant dense<1> : tensor<3xi16> | ||||||
|  |   %cst = mhlo.constant dense<1> : tensor<3xi16> | ||||||
|  |   return %cst : tensor<3xi16> | ||||||
|  | } | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue