[HLO] Delete LHLO memref cast ops and migrate to STD ones.
PiperOrigin-RevId: 340663578
commit 3d930d08c2 (parent 82031b356c)
@@ -314,169 +314,6 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
   );
 }
 
-//===----------------------------------------------------------------------===//
-// StaticMemRefCastOp
-//===----------------------------------------------------------------------===//
-
-def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
-    [NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
-  let summary = [{
-    modifies the offset, sizes and strides of a statically shaped memref
-  }];
-  let description = [{
-    Casts the statically shaped memref operand to a memref with optionally
-    modified offsets, sizes and strides.
-
-    Example:
-    ```mlir
-    %buf_transformed =
-        lmhlo.static_memref_cast %buf
-        : memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]>
-
-    // The result of the op is a rank-1 memref with `[5]` shape, stride 1 and
-    // offset 2.
-    ```
-  }];
-
-  let arguments = (ins Arg<LHLO_Buffer, "", []>:$operand);
-  let results = (outs Res<LHLO_Buffer, "", []>:$result);
-
-  let builders = [
-    OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$operand),
-    [{
-      $_state.addOperands(operand);
-      $_state.types.push_back(resultType);
-    }]>];
-
-  let extraClassDeclaration = [{
-    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
-  }];
-
-  let verifier = [{ return Verify(*this); }];
-  let assemblyFormat = [{
-    $operand attr-dict `:` type($operand) `->` type($result)
-  }];
-}
-
-//===----------------------------------------------------------------------===//
-// DynamicMemRefCastOp
-//===----------------------------------------------------------------------===//
-
-def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
-    [SameVariadicOperandSize, NoSideEffect,
-     DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
-  let summary = "dynamic memref cast operation";
-  let description = [{
-    Change sizes and strides of a memref using the values computed in runtime.
-
-    Example:
-    ```mlir
-    %buf_transformed =
-        lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
-        : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
-    // The result of the op is a type-erased memref with `[%size_X, %size_Y]`
-    // shape and `[%step_X, %step_Y]` strides. The offset will be inherited
-    // from the input.
-    ```
-  }];
-
-  let arguments = (ins
-    Arg<LHLO_Buffer, "", []>:$operand,
-    Variadic<Index>:$sizes,
-    Variadic<Index>:$strides
-  );
-  let results = (outs Res<LHLO_Buffer, "", []>:$result);
-
-  let builders = [
-    OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$operand,
-      "ValueRange":$sizes, "ValueRange":$strides),
-    [{
-      $_state.addOperands(operand);
-      $_state.addOperands(sizes);
-      $_state.addOperands(strides);
-      $_state.types.push_back(resultType);
-     }]>];
-
-  let extraClassDeclaration = [{
-    MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
-  }];
-
-  let verifier = [{ return Verify(*this); }];
-  let assemblyFormat = [{
-    $operand `(` $sizes `)` `[` $strides `]` attr-dict `:` type($operand) `->`
-    type($result)
-  }];
-}
-
-//===----------------------------------------------------------------------===//
-// ReshapeMemRefCastOp
-//===----------------------------------------------------------------------===//
-
-def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
-    DeclareOpInterfaceMethods<ViewLikeOpInterface>,
-    NoSideEffect]>  {
-  let summary = "reshape memref cast operation";
-  let description = [{
-    The `reshape_memref_cast` operation converts a memref from one type to an
-    equivalent type with a provided shape. The data is never copied or moved.
-    The source and destination types are compatible if both have the same
-    element type, address space and identity layout map. The following
-    combinations are possible:
-
-    a. Both are ranked memref types.
-
-    ```mlir
-    // Reshape statically-shaped memref.
-    %dst = reshape_memref_cast %src(%shape)
-             : (memref<4x1xf32>, memref<1xi32>) to memref<4xf32>
-    %dst0 = reshape_memref_cast %src(%shape0)
-             : (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32>
-    ```
-
-    b. Source type is ranked, destination type is unranked.
-
-    ```mlir
-    // Reshape dynamically-shaped 1D memref.
-    %dst = reshape_memref_cast %src(%shape)
-             : (memref<?xf32>, memref<?xi32>) to memref<*xf32>
-    ```
-
-    c. Source type is unranked, destination type is ranked.
-
-    ```mlir
-    // Flatten unranked memref.
-    %dst = reshape_memref_cast %src(%shape)
-             : (memref<*xf32>, memref<1xi32>) to memref<?xf32>
-    ```
-
-    d. Both are unranked memref types.
-
-    ```mlir
-    // Reshape unranked memref.
-    %dst = reshape_memref_cast %src(%shape)
-             : (memref<*xf32>, memref<?xi32>) to memref<*xf32>
-    ```
-  }];
-
-  let arguments = (ins
-    AnyRankedOrUnrankedMemRef:$operand,
-    LHLO_ExtentBuffer:$shape
-  );
-  let results = (outs AnyRankedOrUnrankedMemRef:$result);
-
-  let extraClassDeclaration = [{
-    BaseMemRefType getType() {
-        return getResult().getType().cast<BaseMemRefType>(); }
-  }];
-
-  let verifier = [{ return Verify(*this); }];
-  let assemblyFormat = [{
-    $operand `(` $shape `)` attr-dict `:` `(` type($operand) `,` type($shape)
-    `)` `->` type($result)
-  }];
-}
-
-
 //===----------------------------------------------------------------------===//
 // LMHLO Other op definitions.
 //===----------------------------------------------------------------------===//
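The deletion is possible because each of the three LHLO ops above now has a standard-dialect counterpart. A rough sketch of the mapping, pieced together from the updated tests later in this diff; the operand names and the `#map` layout below are illustrative, not taken from any single test:

```mlir
// static_memref_cast and dynamic_memref_cast both map onto
// memref_reinterpret_cast, which spells out offset, sizes and strides
// explicitly; runtime values are passed as SSA operands.
%transformed = memref_reinterpret_cast %buf to
    offset: [0], sizes: [%size_x, %size_y], strides: [%stride_x, %stride_y]
    : memref<?x?xf32> to memref<?x?xf32, #map>

// reshape_memref_cast maps onto memref_reshape, which keeps the same
// shape-operand semantics (ranked or unranked on either side).
%reshaped = memref_reshape %src(%shape)
    : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
```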
@@ -46,12 +46,6 @@ def LhloLegalizeToGpuPass : Pass<"lhlo-legalize-to-gpu", "FuncOp"> {
 }
 
 
-def TestLhloToLLVMPass : Pass<"test-lhlo-legalize-to-llvm", "FuncOp"> {
-  let summary = "Legalize from LHLO dialect to LLVM.";
-  let constructor = "createTestLhloToLLVMPass()";
-}
-
-
 def LhloLegalizeToParallelLoopsPass : Pass<"lhlo-legalize-to-parallel-loops", "FuncOp"> {
   let summary = "Legalize from LHLO dialect to parallel loops.";
   let constructor = "createLegalizeLhloToParallelLoopsPass()";
@@ -35,8 +35,6 @@ inline void registerAllMhloPasses() { registerMHLOPasses(); }
 
 namespace lmhlo {
 
-std::unique_ptr<Pass> createTestLhloToLLVMPass();
-
 #define GEN_PASS_REGISTRATION
 #include "mlir-hlo/Dialect/mhlo/transforms/lmhlo_passes.h.inc"
 
@@ -24,8 +24,6 @@ limitations under the License.
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
-class LLVMTypeConverter;
-class LowerToLLVMOptions;
 class OwningRewritePatternList;
 
 // Populates a collection of rewrite patterns to realize element-wise operations
@@ -94,14 +92,6 @@ void PopulateTrigonometricToApproximationPatterns(
 
 }  // namespace mhlo
 
-namespace lmhlo {
-
-/// Collect a set of patterns to convert from the LHLO dialect to LLVM.
-void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter,
-                                          OwningRewritePatternList *patterns);
-
-}  // namespace lmhlo
-
 namespace chlo {
 
 // Populates a collection of conversion patterns for legalizing client-HLO to
@@ -88,76 +88,6 @@ void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
   results.insert<EraseConstOp>(context);
 }
 
-//===----------------------------------------------------------------------===//
-// StaticMemRefCastOp
-//===----------------------------------------------------------------------===//
-
-Value StaticMemRefCastOp::getViewSource() { return *getODSOperands(0).begin(); }
-
-static LogicalResult Verify(StaticMemRefCastOp op) {
-  if (!op.operand().getType().cast<ShapedType>().hasStaticShape())
-    return op.emitOpError("operand must have static shape");
-  if (!op.getType().hasStaticShape())
-    return op.emitOpError("result must have static shape");
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// DynamicMemRefCastOp
-//===----------------------------------------------------------------------===//
-
-Value DynamicMemRefCastOp::getViewSource() {
-  return *getODSOperands(0).begin();
-}
-
-static LogicalResult Verify(DynamicMemRefCastOp op) {
-  // Check if `sizes` and `strides` args are compatible with the result type.
-  if (op.sizes().size() != op.getType().getRank())
-    return op.emitOpError(
-        "`sizes` args count must be equal to the rank of the output memref");
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
-// ReshapeMemrefCastOp
-//===----------------------------------------------------------------------===//
-
-Value ReshapeMemRefCastOp::getViewSource() { return operand(); }
-
-static LogicalResult Verify(ReshapeMemRefCastOp op) {
-  Type operandType = op.operand().getType();
-  Type resultType = op.result().getType();
-
-  Type operandElementType = operandType.cast<ShapedType>().getElementType();
-  Type resultElementType = resultType.cast<ShapedType>().getElementType();
-  if (operandElementType != resultElementType)
-    return op.emitOpError(
-        "element types of source and destination memref "
-        "types should be the same");
-
-  if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
-    if (!operandMemRefType.getAffineMaps().empty())
-      return op.emitOpError(
-          "operand memref type should have identity affine map");
-
-  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
-  auto resultMemRefType = resultType.dyn_cast<MemRefType>();
-  if (resultMemRefType) {
-    if (shapeSize == ShapedType::kDynamicSize)
-      return op.emitOpError(
-          "cannot use shape operand with dynamic length to "
-          "cast statically-ranked memref type");
-    if (shapeSize != resultMemRefType.getRank())
-      return op.emitOpError(
-          "length of shape operand differs from the result's memref rank");
-
-    if (!resultMemRefType.getAffineMaps().empty())
-      return op.emitOpError(
-          "result memref type should have identity affine map");
-  }
-  return success();
-}
-
 }  // namespace lmhlo
 }  // namespace mlir
 
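With the ops gone, their `Verify` hooks above go too. As a reminder of what they enforced, here is a hypothetical snippet the dynamic-cast verifier would have rejected (one `sizes` operand against a rank-2 result); the std-dialect cast ops carry their own analogous verification:

```mlir
// Would trigger the deleted diagnostic "`sizes` args count must be equal
// to the rank of the output memref": rank-2 result, single size operand.
%bad = lmhlo.dynamic_memref_cast %buf(%size)[%stride]
    : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
```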
@@ -134,8 +134,6 @@ add_mlir_library(LmhloPasses
   lhlo_fuse_linalg.cc
   lhlo_legalize_to_affine.cc
   lhlo_legalize_to_gpu.cc
-  lhlo_legalize_to_llvm.cc
-  lhlo_legalize_to_llvm_pass.cc
   lhlo_legalize_to_parallel_loops.cc
 
   DEPENDS
@@ -206,7 +206,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
   // Inserts dynamic memref to change the layout of the memref to put 0-stride
   // and size of the target dimension if size-1 dimension expansion is
   // necessary.
-  lmhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
+  MemRefReinterpretCastOp InsertDynamicMemrefCastOp(
       mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
     auto loc = op.getLoc();
     auto operand_type = operand.getType().cast<MemRefType>();
@@ -259,8 +259,13 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
         makeStridedLinearLayoutMap(dynamic_layout,
                                    /*offset=*/0, b->getContext()));
 
-    auto transformed_operand = b->create<lmhlo::DynamicMemRefCastOp>(
-        loc, type_erased_memref_type, operand, sizes, strides);
+    SmallVector<int64_t, 2> static_sizes(sizes.size(),
+                                         ShapedType::kDynamicSize);
+    SmallVector<int64_t, 2> static_strides(strides.size(),
+                                           ShapedType::kDynamicStrideOrOffset);
+    auto transformed_operand = b->create<MemRefReinterpretCastOp>(
+        loc, type_erased_memref_type, operand, /*offset=*/0, static_sizes,
+        static_strides, llvm::None, sizes, strides);
     return transformed_operand;
   }
 };
@@ -284,7 +289,7 @@ struct HloToLhloDynamicReshapeConverter
       return failure();
     }
     mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
-    rewriter.replaceOpWithNewOp<lmhlo::ReshapeMemRefCastOp>(
+    rewriter.replaceOpWithNewOp<MemRefReshapeOp>(
         op, result_type, adaptor.operand(), adaptor.output_shape());
     return success();
   }
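The new builder call passes `kDynamicSize`/`kDynamicStrideOrOffset` placeholders for the static size and stride lists and the runtime values as SSA operands, so the emitted IR comes out roughly as below; this mirrors the updated `dyn_broadcast` CHECK pattern near the end of this diff, with `#map` standing for the strided layout built by `makeStridedLinearLayoutMap`:

```mlir
// Fully dynamic sizes and strides, offset pinned to 0 as in the
// deleted lmhlo.dynamic_memref_cast lowering:
%transformed = memref_reinterpret_cast %operand to
    offset: [0], sizes: [%result_dim_1, %result_dim_2],
    strides: [%stride_0, %stride_1]
    : memref<?x?xf32> to memref<?x?xf32, #map>
```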
@@ -1,370 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Transforms/DialectConversion.h"
-
-namespace mlir {
-namespace lmhlo {
-namespace {
-
-struct StaticMemRefCastOpConverter
-    : public ConvertOpToLLVMPattern<StaticMemRefCastOp> {
-  using ConvertOpToLLVMPattern<StaticMemRefCastOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult matchAndRewrite(
-      Operation *op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    auto loc = op->getLoc();
-    auto cast_op = cast<StaticMemRefCastOp>(op);
-
-    StaticMemRefCastOp::Adaptor operands_adaptor(operands);
-    MemRefDescriptor sourceMemRef(operands_adaptor.operand());
-
-    MemRefType targetMemRefType =
-        cast_op.getResult().getType().cast<MemRefType>();
-    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
-                                      .dyn_cast_or_null<LLVM::LLVMType>();
-    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
-      return failure();
-    // Create descriptor.
-    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
-    Type llvmTargetElementTy = desc.getElementPtrType();
-    // Set allocated ptr.
-    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
-    allocated =
-        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
-    desc.setAllocatedPtr(rewriter, loc, allocated);
-    // Set aligned ptr.
-    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
-    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
-    desc.setAlignedPtr(rewriter, loc, ptr);
-
-    // Fill size and stride descriptors in memref.
-    auto target_sizes = targetMemRefType.getShape();
-    int64_t target_offset;
-    llvm::SmallVector<int64_t, 4> target_strides;
-    if (failed((getStridesAndOffset(targetMemRefType, target_strides,
-                                    target_offset))))
-      return failure();
-
-    // Copy offset of `targetMemRef`.
-    desc.setConstantOffset(rewriter, loc, target_offset);
-    for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
-      desc.setConstantSize(rewriter, loc, i, target_sizes[i]);
-      desc.setConstantStride(rewriter, loc, i, target_strides[i]);
-    }
-    rewriter.replaceOp(op, {desc});
-    return success();
-  }
-};
-
-struct DynamicMemRefCastOpConverter
-    : public ConvertOpToLLVMPattern<DynamicMemRefCastOp> {
-  using ConvertOpToLLVMPattern<DynamicMemRefCastOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult matchAndRewrite(
-      Operation *op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    auto loc = op->getLoc();
-    auto cast_op = cast<DynamicMemRefCastOp>(op);
-
-    DynamicMemRefCastOp::Adaptor operands_adaptor(operands);
-    MemRefDescriptor sourceMemRef(operands_adaptor.operand());
-
-    MemRefType targetMemRefType =
-        cast_op.getResult().getType().cast<MemRefType>();
-    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
-                                      .dyn_cast_or_null<LLVM::LLVMType>();
-    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
-      return failure();
-    // Create descriptor.
-    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
-    Type llvmTargetElementTy = desc.getElementPtrType();
-    // Set allocated ptr.
-    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
-    allocated =
-        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
-    desc.setAllocatedPtr(rewriter, loc, allocated);
-    // Set aligned ptr.
-    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
-    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
-    desc.setAlignedPtr(rewriter, loc, ptr);
-    // Copy offset of `sourceMemRef`.
-    desc.setOffset(rewriter, loc, sourceMemRef.offset(rewriter, loc));
-
-    // Fill size and stride descriptors in memref.
-    if (!cast_op.sizes().empty()) {
-      auto sizes = operands_adaptor.sizes();
-      auto strides = operands_adaptor.strides();
-      for (int i = 0, e = targetMemRefType.getRank(); i < e; ++i) {
-        desc.setSize(rewriter, loc, i, sizes[i]);
-        desc.setStride(rewriter, loc, i, strides[i]);
-      }
-    }
-    rewriter.replaceOp(op, {desc});
-    return success();
-  }
-};
-
-struct ReshapeMemRefCastOpConverter
-    : public ConvertOpToLLVMPattern<ReshapeMemRefCastOp> {
-  using ConvertOpToLLVMPattern<ReshapeMemRefCastOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult matchAndRewrite(
-      Operation *op, ArrayRef<Value> operands,
-      ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-
-    auto reshape_op = cast<ReshapeMemRefCastOp>(op);
-    auto dst_type = reshape_op.getResult().getType().cast<BaseMemRefType>();
-    auto element_type = dst_type.getElementType();
-
-    auto shape = reshape_op.shape();
-
-    ReshapeMemRefCastOp::Adaptor operands_adaptor(operands);
-    PtrsAndOffset ptrs_n_offset = ExtractMemRefPtrsAndOffset(
-        loc, reshape_op.operand(), operands_adaptor.operand(), &rewriter);
-
-    MemRefDescriptor shape_desc(operands_adaptor.shape());
-
-    auto shape_memref_type = shape.getType().cast<MemRefType>();
-
-    if (shape_memref_type.hasStaticShape()) {
-      auto shape_length = shape_memref_type.getDimSize(0);
-
-      MemRefType targetMemRefType = MemRefType::get(
-          SmallVector<int64_t, 1>(shape_length, 1), element_type);
-      auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
-                                        .dyn_cast_or_null<LLVM::LLVMType>();
-      if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
-        return failure();
-      // Create descriptor.
-      auto desc =
-          MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
-      desc.setAllocatedPtr(rewriter, loc, ptrs_n_offset.allocated_ptr);
-      desc.setAlignedPtr(rewriter, loc, ptrs_n_offset.aligned_ptr);
-      desc.setOffset(rewriter, loc, ptrs_n_offset.offset);
-
-      auto llvm_index_type = typeConverter.getIndexType();
-      auto llvm_index_ptr_type = llvm_index_type.getPointerTo();
-      Value stride_carried = rewriter.create<LLVM::ConstantOp>(
-          loc, llvm_index_type,
-          rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
-      for (int i = shape_length - 1; i >= 0; --i) {
-        Value pos = rewriter.create<LLVM::ConstantOp>(
-            loc, llvm_index_type,
-            rewriter.getIntegerAttr(rewriter.getIndexType(), i));
-        Value ptr = rewriter.create<LLVM::GEPOp>(
-            loc, llvm_index_ptr_type, shape_desc.alignedPtr(rewriter, loc),
-            ValueRange{pos});
-        Value extracted_size = rewriter.create<LLVM::LoadOp>(loc, ptr);
-        desc.setSize(rewriter, loc, i, extracted_size);
-        desc.setStride(rewriter, loc, i, stride_carried);
-        // Update stride
-        if (i > 0) {
-          stride_carried =
-              rewriter.create<LLVM::MulOp>(loc, stride_carried, extracted_size);
-        }
-      }
-      if (dst_type.isa<MemRefType>()) {
-        rewriter.replaceOp(op, {desc});
-      } else {
-        Value rank = rewriter.create<LLVM::ConstantOp>(
-            loc, llvm_index_type,
-            rewriter.getIntegerAttr(rewriter.getIndexType(), shape_length));
-        Value alloca =
-            typeConverter.promoteOneMemRefDescriptor(loc, desc, rewriter);
-        Value void_ptr =
-            rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), alloca);
-        auto unranked_desc = UnrankedMemRefDescriptor::pack(
-            rewriter, loc, typeConverter, dst_type.cast<UnrankedMemRefType>(),
-            {rank, void_ptr});
-        rewriter.replaceOp(op, {unranked_desc});
-      }
-      return success();
-    }
-
-    // The shape is a rank-1 tensor with unknown length.
-    Value result_rank = shape_desc.size(rewriter, loc, 0);
-    // TODO(herhut): Propely handle address spaces.
-    unsigned address_space = 0;
-    auto target_type =
-        typeConverter
-            .convertType(UnrankedMemRefType::get(element_type, address_space))
-            .cast<LLVM::LLVMType>();
-    // Create the unranked memref descriptor that holds the ranked one. The
-    // inner descriptor is allocated on stack.
-    UnrankedMemRefDescriptor target_desc =
-        UnrankedMemRefDescriptor::undef(rewriter, loc, target_type);
-    target_desc.setRank(rewriter, loc, result_rank);
-    SmallVector<Value, 1> sizes;
-    UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter,
-                                           {target_desc}, sizes);
-    auto void_ptr_type = LLVM::LLVMType::getInt8PtrTy(rewriter.getContext());
-    Value ranked_desc_mem = rewriter.create<LLVM::AllocaOp>(
-        loc, void_ptr_type, sizes.front(), llvm::None);
-    target_desc.setMemRefDescPtr(rewriter, loc, ranked_desc_mem);
-
-    // Fill the fixed parts. For this, we cast to a 0-D memref.
-    auto zero_d_memref_type = MemRefType::get({}, element_type);
-    Value as_zero_d = rewriter.create<LLVM::BitcastOp>(
-        loc,
-        typeConverter.convertType(zero_d_memref_type)
-            .cast<LLVM::LLVMType>()
-            .getPointerTo(address_space),
-        ranked_desc_mem);
-    // Some common constants. Use 32 bit where required by gep struct indexes.
-    auto int32_type = typeConverter.convertType(rewriter.getI32Type());
-    Value zero_index = rewriter.create<LLVM::ConstantOp>(
-        loc, typeConverter.getIndexType(), rewriter.getIndexAttr(0));
-    Value zero = rewriter.create<LLVM::ConstantOp>(
-        loc, int32_type, rewriter.getI32IntegerAttr(0));
-    Value one = rewriter.create<LLVM::ConstantOp>(
-        loc, int32_type, rewriter.getI32IntegerAttr(1));
-    Value two = rewriter.create<LLVM::ConstantOp>(
-        loc, int32_type, rewriter.getI32IntegerAttr(2));
-    // Set base_pointer and aligned pointer.
-    auto element_ptr_ptr_type = typeConverter.convertType(element_type)
-                                    .cast<LLVM::LLVMType>()
-                                    .getPointerTo(address_space)
-                                    .getPointerTo(address_space);
-    auto base_gep = rewriter.create<LLVM::GEPOp>(
-        loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, zero}));
-    rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.allocated_ptr, base_gep);
-    auto aligned_gep = rewriter.create<LLVM::GEPOp>(
-        loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, one}));
-    rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.aligned_ptr, aligned_gep);
-    // Set offset.
-    auto index_ptr_type =
-        typeConverter.getIndexType().getPointerTo(address_space);
-    auto offset_gep = rewriter.create<LLVM::GEPOp>(
-        loc, index_ptr_type, as_zero_d, ValueRange({zero_index, two}));
-    rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.offset, offset_gep);
-
-    // Use the offset pointer as base for further addressing. Copy over the
-    // new shape and compute strides. For this, we need to create a loop from
-    // rank - 1 to 0.
-    Value one_index = rewriter.create<LLVM::ConstantOp>(
-        loc, typeConverter.getIndexType(), rewriter.getIndexAttr(1));
-    auto target_shape_base = rewriter.create<LLVM::GEPOp>(
-        loc, index_ptr_type, offset_gep, ValueRange({one}));
-    auto target_strides_base = rewriter.create<LLVM::GEPOp>(
-        loc, index_ptr_type, target_shape_base, ValueRange({result_rank}));
-    auto shape_ptr = shape_desc.alignedPtr(rewriter, loc);
-    auto result_rank_minus_one =
-        rewriter.create<LLVM::SubOp>(loc, result_rank, one_index);
-
-    Block *init_block = rewriter.getInsertionBlock();
-    Block *cond_block =
-        rewriter.splitBlock(init_block, rewriter.getInsertionPoint());
-    rewriter.setInsertionPointToEnd(init_block);
-    rewriter.create<LLVM::BrOp>(
-        loc, ValueRange({result_rank_minus_one, one_index}), cond_block);
-    rewriter.setInsertionPointToStart(cond_block);
-    auto index_arg = cond_block->addArgument(typeConverter.getIndexType());
-    auto stride_arg = cond_block->addArgument(typeConverter.getIndexType());
-    auto pred = rewriter.create<LLVM::ICmpOp>(
-        loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()),
-        LLVM::ICmpPredicate::sge, index_arg, zero_index);
-
-    Block *body_block =
-        rewriter.splitBlock(cond_block, rewriter.getInsertionPoint());
-    rewriter.setInsertionPointToStart(body_block);
-
-    // Copy size from shape to descriptor.
-    auto size_load_gep = rewriter.create<LLVM::GEPOp>(
-        loc, index_ptr_type, shape_ptr, ValueRange{index_arg});
-    auto extracted_size = rewriter.create<LLVM::LoadOp>(loc, size_load_gep);
-    auto size_store_gep = rewriter.create<LLVM::GEPOp>(
-        loc, index_ptr_type, target_shape_base, ValueRange({index_arg}));
-    rewriter.create<LLVM::StoreOp>(loc, extracted_size, size_store_gep);
-    // Write stride value and compute next one.
-    auto stride_store_gep = rewriter.create<LLVM::GEPOp>(
-        loc, index_ptr_type, target_strides_base, ValueRange({index_arg}));
-    rewriter.create<LLVM::StoreOp>(loc, stride_arg, stride_store_gep);
-    auto next_stride =
-        rewriter.create<LLVM::MulOp>(loc, stride_arg, extracted_size);
-
-    // Decrement loop counter and branch back.
-    auto decrement = rewriter.create<LLVM::SubOp>(loc, index_arg, one_index);
-    rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, next_stride}),
-                                cond_block);
-
-    Block *remainder =
-        rewriter.splitBlock(body_block, rewriter.getInsertionPoint());
-
-    // Hook up the cond exit to the remainder.
-    rewriter.setInsertionPointToEnd(cond_block);
-    rewriter.create<LLVM::CondBrOp>(loc, pred, body_block, ValueRange(),
-                                    remainder, ValueRange());
-
-    // Reset position to beginning of new remainder block.
-    rewriter.setInsertionPointToStart(remainder);
-    rewriter.replaceOp(op, {target_desc});
-    return success();
-  }
-
- private:
-  struct PtrsAndOffset {
-    Value allocated_ptr;
-    Value aligned_ptr;
-    Value offset;
-  };
-
-  PtrsAndOffset ExtractMemRefPtrsAndOffset(
-      Location loc, Value originalOperand, Value convertedOperand,
-      ConversionPatternRewriter *rewriter) const {
-    Type operandType = originalOperand.getType();
-    Value descriptor_ptr;
-    if (operandType.isa<MemRefType>()) {
-      descriptor_ptr = convertedOperand;
-    } else {
-      UnrankedMemRefDescriptor unranked_descriptor(convertedOperand);
-      Value underlying_desc_ptr =
-          unranked_descriptor.memRefDescPtr(*rewriter, loc);
-
-      Type element_type =
-          operandType.cast<UnrankedMemRefType>().getElementType();
-      LLVM::LLVMType memref_type_0d =
-          typeConverter.convertType(MemRefType::get(/*shape=*/{}, element_type))
-              .cast<LLVM::LLVMType>();
-      descriptor_ptr = rewriter->create<LLVM::BitcastOp>(
-          loc, memref_type_0d.getPointerTo(), underlying_desc_ptr);
-      descriptor_ptr = rewriter->create<LLVM::LoadOp>(loc, descriptor_ptr);
-    }
-    MemRefDescriptor descriptor(descriptor_ptr);
-    PtrsAndOffset result;
-    result.allocated_ptr = descriptor.allocatedPtr(*rewriter, loc);
-    result.aligned_ptr = descriptor.alignedPtr(*rewriter, loc);
-    result.offset = descriptor.offset(*rewriter, loc);
-    return result;
-  }
-};
-
-}  // namespace
-
-void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter,
-                                          OwningRewritePatternList *patterns) {
-  patterns->insert<DynamicMemRefCastOpConverter, ReshapeMemRefCastOpConverter,
-                   StaticMemRefCastOpConverter>(*converter);
-}
-
-}  // namespace lmhlo
-}  // namespace mlir
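The bulk of the deleted converter above is the right-to-left stride computation for `reshape_memref_cast`: walking `i` from `rank - 1` down to `0`, each dimension gets the running product of the sizes to its right (`stride_carried`). The std lowering of `memref_reshape` performs the equivalent computation, so the same reshapes now simply lower through the standard path. A sketch, with the shape buffer assumed to hold `[3, 1, 2]` as in the deleted runtime test further down:

```mlir
// sizes [3, 1, 2] give row-major strides [2, 2, 1]: stride[2] = 1,
// stride[1] = 1 * 2 = 2, stride[0] = 2 * 1 = 2, matching the
// "rank = 3 ... strides = [2, 2, 1]" output the deleted test expects.
func @reshape_example(%input: memref<2x3xf32>, %shape: memref<3xi64>)
    -> memref<*xf32> {
  %out = memref_reshape %input(%shape)
      : (memref<2x3xf32>, memref<3xi64>) -> memref<*xf32>
  return %out : memref<*xf32>
}
```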
@@ -1,63 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
-#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/IR/StandardTypes.h"
-#include "mlir/Pass/Pass.h"
-
-namespace mlir {
-namespace lmhlo {
-namespace {
-
-class TestLhloToLLVMPass
-    : public ::mlir::PassWrapper<TestLhloToLLVMPass,
-                                 ::mlir::OperationPass<::mlir::ModuleOp>> {
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<LLVM::LLVMDialect>();
-  }
-
- public:
-  void runOnOperation() override {
-    ModuleOp m = getOperation();
-
-    OwningRewritePatternList patterns;
-    LLVMTypeConverter converter(&getContext());
-    populateStdToLLVMConversionPatterns(converter, patterns);
-    PopulateLhloToLLVMConversionPatterns(&converter, &patterns);
-
-    ConversionTarget target(getContext());
-    target.addLegalDialect<LLVM::LLVMDialect>();
-    target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
-    target.addIllegalDialect<LmhloDialect>();
-
-    if (failed(applyFullConversion(m, target, std::move(patterns)))) {
-      signalPassFailure();
-    }
-  }
-};
-
-}  // namespace
-
-std::unique_ptr<Pass> createTestLhloToLLVMPass() {
-  return std::make_unique<TestLhloToLLVMPass>();
-}
-
-}  // namespace lmhlo
-}  // namespace mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
+// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -canonicalize -cse -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
 
 func @main() -> () {
   call @trivial_broadcast_wrapper() : () -> ()
@@ -3,7 +3,7 @@
 // RUN: -buffer-deallocation -copy-removal -canonicalize -cse \
 // RUN: -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops \
 // RUN: -lower-affine -convert-scf-to-std -canonicalize -cse \
-// RUN: -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main \
+// RUN: -convert-std-to-llvm | mlir-cpu-runner -e main \
 // RUN: -entry-point-result=void \
 // RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | \
 // RUN: FileCheck %s
@@ -1,190 +0,0 @@
-// RUN: mlir-hlo-opt %s -chlo-legalize-to-hlo -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -copy-removal -canonicalize -cse -lhlo-legalize-to-linalg -lhlo-fuse-linalg -convert-linalg-to-loops -convert-scf-to-std -canonicalize -cse -test-lhlo-legalize-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
-
-func @main() -> () {
-  call @reshape_with_static_shape_size_matrix_to_1D() : () -> ()
-  call @reshape_with_static_shape_size_matrix_to_3D() : () -> ()
-  call @reshape_with_dynamic_shape_size_matrix_to_1D() : () -> ()
-  call @reshape_with_dynamic_shape_size_matrix_to_3D() : () -> ()
-  return
-}
-
-func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
-
-func @reshape_with_static_shape_size_matrix_to_1D() {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-
-  // Initialize input.
-  %input = alloc() : memref<2x3xf32>
-  %dim_x = dim %input, %c0 : memref<2x3xf32>
-  %dim_y = dim %input, %c1 : memref<2x3xf32>
-  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
-    %i_i64 = index_cast %i : index to i64
-    %i_f32 = sitofp %i_i64 : i64 to f32
-    store %i_f32, %input[%i, %j] : memref<2x3xf32>
-  }
-  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
-  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
-  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
-  // CHECK: [0,   0,   0]
-  // CHECK: [1,   1,   1]
-
-  // Initialize shape.
-  %shape = alloc() : memref<1xi64>
-  %num_elements = muli %dim_x, %dim_y : index
-  %num_elements_i64 = index_cast %num_elements : index to i64
-  store %num_elements_i64, %shape[%c0] : memref<1xi64>
-
-  // 1. Ranked input, ranked output.
-  %output_1 = lmhlo.reshape_memref_cast %input(%shape)
-                 : (memref<2x3xf32>, memref<1xi64>) -> memref<6xf32>
-  %unranked_output_1 = memref_cast %output_1 : memref<6xf32> to memref<*xf32>
-  call @print_memref_f32(%unranked_output_1) : (memref<*xf32>) -> ()
-  // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
-  // CHECK: [0,  0,  0,  1,  1,  1]
-
-  // 2. Ranked input, unranked output.
-  %output_2 = lmhlo.reshape_memref_cast %input(%shape)
-                 : (memref<2x3xf32>, memref<1xi64>) -> memref<*xf32>
-  call @print_memref_f32(%output_2) : (memref<*xf32>) -> ()
-  // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
-  // CHECK: [0,  0,  0,  1,  1,  1]
-
-  // 3. Unranked input, ranked output.
-  %output_3 = lmhlo.reshape_memref_cast %unranked_input(%shape)
-                 : (memref<*xf32>, memref<1xi64>) -> memref<?xf32>
-  %unranked_output_3 = memref_cast %output_3 : memref<?xf32> to memref<*xf32>
-  call @print_memref_f32(%unranked_output_3) : (memref<*xf32>) -> ()
-  // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
-  // CHECK: [0,  0,  0,  1,  1,  1]
-
-  // 4. Unranked input, unranked output.
-  %output_4 = lmhlo.reshape_memref_cast %unranked_input(%shape)
-                 : (memref<*xf32>, memref<1xi64>) -> memref<*xf32>
-  call @print_memref_f32(%output_4) : (memref<*xf32>) -> ()
-  // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
-  // CHECK: [0,  0,  0,  1,  1,  1]
-  return
-}
-
-func @reshape_with_static_shape_size_matrix_to_3D() {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %c2 = constant 2 : index
-
-  // Initialize input.
-  %input = alloc() : memref<2x3xf32>
-  %dim_x = dim %input, %c0 : memref<2x3xf32>
-  %dim_y = dim %input, %c1 : memref<2x3xf32>
-  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
-    %i_i64 = index_cast %i : index to i64
-    %i_f32 = sitofp %i_i64 : i64 to f32
-    store %i_f32, %input[%i, %j] : memref<2x3xf32>
-  }
-  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
-  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
-  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
-  // CHECK: [0,   0,   0]
-  // CHECK: [1,   1,   1]
-
-  // Initialize shape.
-  %shape = alloc() : memref<3xi64>
-  %c1_i64 = constant 1 : i64
-  %c2_i64 = constant 2 : i64
-  %c3_i64 = constant 3 : i64
-  store %c3_i64, %shape[%c0] : memref<3xi64>
-  store %c1_i64, %shape[%c1] : memref<3xi64>
-  store %c2_i64, %shape[%c2] : memref<3xi64>
-
-  // Static shape input and shape, dynamic output.
-  %unranked_output = lmhlo.reshape_memref_cast %input(%shape)
-                 : (memref<2x3xf32>, memref<3xi64>) -> memref<*xf32>
-  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
-  // CHECK: rank = 3 offset = 0 sizes = [3, 1, 2] strides = [2, 2, 1]
-  // CHECK: {{\[}}{{\[}}[0,    0]],
-  // CHECK:       {{\[}}[0,    1]],
-  // CHECK:       {{\[}}[1,    1]]]
-  return
-}
-
-func @reshape_with_dynamic_shape_size_matrix_to_1D() {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-
-  // Initialize input.
-  %input = alloc() : memref<2x3xf32>
-  %dim_x = dim %input, %c0 : memref<2x3xf32>
-  %dim_y = dim %input, %c1 : memref<2x3xf32>
-  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
-    %i_i64 = index_cast %i : index to i64
-    %i_f32 = sitofp %i_i64 : i64 to f32
-    store %i_f32, %input[%i, %j] : memref<2x3xf32>
-  }
-  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
-  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
-  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
-  // CHECK: [0,   0,   0]
-  // CHECK: [1,   1,   1]
-
-  // Initialize shape.
-  %shape = alloc(%c1) : memref<?xi64>
-  %num_elements = muli %dim_x, %dim_y : index
-  %num_elements_i64 = index_cast %num_elements : index to i64
-  store %num_elements_i64, %shape[%c0] : memref<?xi64>
-
-  // 1. Ranked input, unranked output.
-  %output_2 = lmhlo.reshape_memref_cast %input(%shape)
-                 : (memref<2x3xf32>, memref<?xi64>) -> memref<*xf32>
-  call @print_memref_f32(%output_2) : (memref<*xf32>) -> ()
-  // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
-  // CHECK: [0,  0,  0,  1,  1,  1]
-
-  // 2. Unranked input, unranked output.
-  %output_4 = lmhlo.reshape_memref_cast %unranked_input(%shape)
-                 : (memref<*xf32>, memref<?xi64>) -> memref<*xf32>
-  call @print_memref_f32(%output_4) : (memref<*xf32>) -> ()
-  // CHECK: rank = 1 offset = 0 sizes = [6] strides = [1]
-  // CHECK: [0,  0,  0,  1,  1,  1]
-  return
-}
-
-func @reshape_with_dynamic_shape_size_matrix_to_3D() {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %c2 = constant 2 : index
-  %c3 = constant 3 : index
-
-  // Initialize input.
-  %input = alloc() : memref<2x3xf32>
-  %dim_x = dim %input, %c0 : memref<2x3xf32>
-  %dim_y = dim %input, %c1 : memref<2x3xf32>
-  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
-    %i_i64 = index_cast %i : index to i64
-    %i_f32 = sitofp %i_i64 : i64 to f32
-    store %i_f32, %input[%i, %j] : memref<2x3xf32>
-  }
-  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
-  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
-  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
-  // CHECK: [0,   0,   0]
-  // CHECK: [1,   1,   1]
-
-  // Initialize shape.
-  %shape = alloc(%c3) : memref<?xi64>
-  %c1_i64 = constant 1 : i64
-  %c2_i64 = constant 2 : i64
-  %c3_i64 = constant 3 : i64
-  store %c3_i64, %shape[%c0] : memref<?xi64>
-  store %c1_i64, %shape[%c1] : memref<?xi64>
-  store %c2_i64, %shape[%c2] : memref<?xi64>
-
-  // Static shape input, dynamic output and shape.
-  %unranked_output = lmhlo.reshape_memref_cast %input(%shape)
-                 : (memref<2x3xf32>, memref<?xi64>) -> memref<*xf32>
-  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
-  // CHECK: rank = 3 offset = 0 sizes = [3, 1, 2] strides = [2, 2, 1]
-  // CHECK: {{\[}}{{\[}}[0,    0]],
-  // CHECK:       {{\[}}[0,    1]],
-  // CHECK:       {{\[}}[1,    1]]]
-  return
-}
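The runtime coverage deleted above is relocated rather than lost: each case can be restated against the std op. A sketch of case 1 (ranked input, ranked output), assuming `memref_reshape` keeps the `%operand(%shape)` syntax shown elsewhere in this diff:

```mlir
%output_1 = memref_reshape %input(%shape)
    : (memref<2x3xf32>, memref<1xi64>) -> memref<6xf32>
// Expected print, unchanged from the deleted CHECK lines:
//   rank = 1 offset = 0 sizes = [6] strides = [1]
```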
@@ -17,7 +17,7 @@ func @dynamic_reshape_from_unranked(
   return %reshaped : tensor<?xf32>
 }
 // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>)
-// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
+// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
 // CHECK-SAME:   : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
 
 // -----
@@ -30,5 +30,5 @@ func @dynamic_reshape_to_unranked(
   return %reshaped : tensor<*xf32>
 }
 // CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
-// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
+// CHECK-NEXT: memref_reshape [[ARG]]([[SHAPE]])
 // CHECK-SAME:   : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
|  | @ -197,10 +197,11 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) { | ||||||
|   // CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]] |   // CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]] | ||||||
|   // CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index |   // CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index | ||||||
|
|   // CHECK: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast |   // CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to | ||||||
|   // CHECK-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]) |   // CHECK-SAME: offset: [0], | ||||||
|   // CHECK-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]] |   // CHECK-SAME: sizes: {{\[}}%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]] | ||||||
|   // CHECK-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map> |   // CHECK-SAME: strides: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]] | ||||||
|  |   // CHECK-SAME: : memref<?x?xf32> to memref<?x?xf32, #map> | ||||||
|
|   // CHECK: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) { |   // CHECK: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) { | ||||||
|   // CHECK-SAME:   broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> |   // CHECK-SAME:   broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> | ||||||
|  |  | ||||||
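The hunk above swaps `lmhlo.dynamic_memref_cast` for the std op `memref_reinterpret_cast`, which makes the offset, sizes, and strides explicit in the assembly. A minimal sketch of the source form the new CHECK lines match, assuming SSA names corresponding to the captured values:

```mlir
// Sketch: dynamic sizes/strides spelled out on the std op; %operand,
// %result_dim_1/2 and %stride_0/1 stand in for the values the CHECKs capture.
%transformed = memref_reinterpret_cast %operand to
    offset: [0], sizes: [%result_dim_1, %result_dim_2],
    strides: [%stride_0, %stride_1]
    : memref<?x?xf32> to memref<?x?xf32, offset: 0, strides: [?, ?]>
```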
|  | @ -267,7 +267,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index) | ||||||
|     %13 = absf %arg3 : f32 |     %13 = absf %arg3 : f32 | ||||||
|     linalg.yield %13 : f32 |     linalg.yield %13 : f32 | ||||||
|   } |   } | ||||||
|   %2 = lmhlo.reshape_memref_cast %1(%arg1) |   %2 = memref_reshape %1(%arg1) | ||||||
|       : (memref<?xf32>, memref<?xindex>) -> memref<*xf32> |       : (memref<?xf32>, memref<?xindex>) -> memref<*xf32> | ||||||
|   return %2 : memref<*xf32> |   return %2 : memref<*xf32> | ||||||
| } | } | ||||||
|  | @ -279,7 +279,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index) | ||||||
| //   CHECK-NOT:  scf.for | //   CHECK-NOT:  scf.for | ||||||
| //       CHECK:      linalg.generic | //       CHECK:      linalg.generic | ||||||
| //       CHECK:        absf | //       CHECK:        absf | ||||||
| //       CHECK:  reshape_memref_cast | //       CHECK:  memref_reshape | ||||||
|
| // TILED-LABEL: func @view_result | // TILED-LABEL: func @view_result | ||||||
| //   TILED-DAG:  %[[C2:.*]] = constant 2 | //   TILED-DAG:  %[[C2:.*]] = constant 2 | ||||||
|  | @ -288,7 +288,7 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index) | ||||||
| //   TILED-NOT:  scf.for | //   TILED-NOT:  scf.for | ||||||
| //       TILED:      linalg.generic | //       TILED:      linalg.generic | ||||||
| //       TILED:        absf | //       TILED:        absf | ||||||
| //       TILED:  reshape_memref_cast | //       TILED:  memref_reshape | ||||||
|
|
| // PLOOP-LABEL: func @view_result | // PLOOP-LABEL: func @view_result | ||||||
|  | @ -297,5 +297,5 @@ func @view_result(%arg0: memref<?xf32>, %arg1: memref<?xindex>, %arg2: index) | ||||||
| //   PLOOP-NOT:  scf.parallel | //   PLOOP-NOT:  scf.parallel | ||||||
| //       PLOOP:      linalg.generic | //       PLOOP:      linalg.generic | ||||||
| //       PLOOP:        absf | //       PLOOP:        absf | ||||||
| //       PLOOP:  reshape_memref_cast | //       PLOOP:  memref_reshape | ||||||
|
|  |  | ||||||
|  | @ -1,65 +0,0 @@ | ||||||
| // RUN: mlir-hlo-opt %s -lower-affine -convert-scf-to-std -test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s |  | ||||||
|
| // CHECK-LABEL: func @static_memref_cast |  | ||||||
| func @static_memref_cast(%buf : memref<10x1x5xf32>) { |  | ||||||
|   %0 = lmhlo.static_memref_cast %buf |  | ||||||
|         : memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]> |  | ||||||
|   return |  | ||||||
| } |  | ||||||
| // CHECK: %[[INPUT_MEMREF_BLDR:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_3D:!.*]] |  | ||||||
| // CHECK: llvm.insertvalue |  | ||||||
| // CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE_2D:!.*]] |  | ||||||
|
| // CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE_3D]] |  | ||||||
| // CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm.ptr<float> to !llvm.ptr<float> |  | ||||||
| // CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE_2D]] |  | ||||||
|
| // CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE_3D]] |  | ||||||
| // CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm.ptr<float> to !llvm.ptr<float> |  | ||||||
| // CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE_2D]] |  | ||||||
|
| // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 |  | ||||||
| // CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[C2]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE_2D]] |  | ||||||
|
| // CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64 |  | ||||||
| // CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE_2D]] |  | ||||||
| // CHECK: %[[C5:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64 |  | ||||||
| // CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C5]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE_2D]] |  | ||||||
| // CHECK: %[[C5_:.*]] = llvm.mlir.constant(5 : index) : !llvm.i64 |  | ||||||
| // CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C5_]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE_2D]] |  | ||||||
| // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 |  | ||||||
| // CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE_2D]] |  | ||||||
|
| // ----- |  | ||||||
|
| // CHECK-LABEL: func @dynamic_memref_cast |  | ||||||
| func @dynamic_memref_cast(%buf : memref<?x?xf32>) { |  | ||||||
|   %size_X = constant 10 : index |  | ||||||
|   %size_Y = constant 50 : index |  | ||||||
|   %stride_X = constant 1 : index |  | ||||||
|   %stride_Y = constant 0 : index |  | ||||||
|   %0 = lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y] |  | ||||||
|         : memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]> |  | ||||||
|   return |  | ||||||
| } |  | ||||||
| // CHECK: %[[C10:.*]] = llvm.mlir.constant(10 : index) : !llvm.i64 |  | ||||||
| // CHECK: %[[C50:.*]] = llvm.mlir.constant(50 : index) : !llvm.i64 |  | ||||||
| // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 |  | ||||||
| // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 |  | ||||||
|
| // CHECK: %[[MEMREF_BLDR_0:.*]] = llvm.mlir.undef : [[DESCRIPTOR_TYPE:!.*]] |  | ||||||
|
| // CHECK: %[[IN_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF:.*]][0] : [[DESCRIPTOR_TYPE]] |  | ||||||
| // CHECK: %[[PTR:.*]] = llvm.bitcast %[[IN_PTR]] : !llvm.ptr<float> to !llvm.ptr<float> |  | ||||||
| // CHECK: %[[MEMREF_BLDR_1:.*]] = llvm.insertvalue %[[PTR]], %[[MEMREF_BLDR_0]][0] : [[DESCRIPTOR_TYPE]] |  | ||||||
|
| // CHECK: %[[IN_ALIGNED_PTR:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][1] : [[DESCRIPTOR_TYPE]] |  | ||||||
| // CHECK: %[[ALIGNED_PTR:.*]] = llvm.bitcast %[[IN_ALIGNED_PTR]] : !llvm.ptr<float> to !llvm.ptr<float> |  | ||||||
| // CHECK: %[[MEMREF_BLDR_2:.*]] = llvm.insertvalue %[[ALIGNED_PTR]], %[[MEMREF_BLDR_1]][1] : [[DESCRIPTOR_TYPE]] |  | ||||||
|
| // CHECK: %[[SRC_OFFSET:.*]] = llvm.extractvalue %[[INPUT_MEMREF]][2] : [[DESCRIPTOR_TYPE]] |  | ||||||
| // CHECK: %[[MEMREF_BLDR_3:.*]] = llvm.insertvalue %[[SRC_OFFSET]], %[[MEMREF_BLDR_2]][2] : [[DESCRIPTOR_TYPE]] |  | ||||||
| // CHECK: %[[MEMREF_BLDR_4:.*]] = llvm.insertvalue %[[C10]], %[[MEMREF_BLDR_3]][3, 0] : [[DESCRIPTOR_TYPE]] |  | ||||||
| // CHECK: %[[MEMREF_BLDR_5:.*]] = llvm.insertvalue %[[C1]], %[[MEMREF_BLDR_4]][4, 0] : [[DESCRIPTOR_TYPE]] |  | ||||||
| // CHECK: %[[MEMREF_BLDR_6:.*]] = llvm.insertvalue %[[C50]], %[[MEMREF_BLDR_5]][3, 1] : [[DESCRIPTOR_TYPE]] |  | ||||||
| // CHECK: %[[MEMREF_BLDR_7:.*]] = llvm.insertvalue %[[C0]], %[[MEMREF_BLDR_6]][4, 1] : [[DESCRIPTOR_TYPE]] |  | ||||||
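With the dedicated lowering gone, a statically shaped cast like the ones tested above is expressed directly with `memref_reinterpret_cast` using constant offset, sizes, and strides, and rides the existing std-to-LLVM lowering. A hedged sketch of the equivalent of the deleted @static_memref_cast case (function name is hypothetical):

```mlir
// Sketch: std-dialect form of the deleted static_memref_cast test; all
// offset/size/stride values are compile-time constants written inline.
func @static_cast_via_std(%buf : memref<10x1x5xf32>) {
  %0 = memref_reinterpret_cast %buf to
      offset: [2], sizes: [10, 5], strides: [5, 1]
      : memref<10x1x5xf32> to memref<10x5xf32, offset: 2, strides: [5, 1]>
  return
}
```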
|  | @ -429,120 +429,6 @@ func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memr | ||||||
|
| // ----- | // ----- | ||||||
|
| func @static_memref_cast(%in: memref<10x1xf32>) { |  | ||||||
|   %out = lmhlo.static_memref_cast %in |  | ||||||
|            : memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]> |  | ||||||
|   return |  | ||||||
| } |  | ||||||
| // CHECK-LABEL: func @static_memref_cast |  | ||||||
|
| // ----- |  | ||||||
|
| func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) { |  | ||||||
|   // expected-error @+1 {{operand must have static shape}} |  | ||||||
|   %out = lmhlo.static_memref_cast %in |  | ||||||
|            : memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]> |  | ||||||
|   return |  | ||||||
| } |  | ||||||
|
| // ----- |  | ||||||
|
| func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) { |  | ||||||
|   // expected-error @+1 {{result must have static shape}} |  | ||||||
|   %out = lmhlo.static_memref_cast %in |  | ||||||
|            : memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]> |  | ||||||
|   return |  | ||||||
| } |  | ||||||
|
| // ----- |  | ||||||
|
| func @dynamic_memref_cast(%in: memref<?xf32>) { |  | ||||||
|   %size = constant 10 : index |  | ||||||
|   %step = constant 1 : index |  | ||||||
|   %out = lmhlo.dynamic_memref_cast %in(%size)[%step] |  | ||||||
|            : memref<?xf32> -> memref<?xf32, offset: 0, strides: [?]> |  | ||||||
|   return |  | ||||||
| } |  | ||||||
| // CHECK-LABEL: func @dynamic_memref_cast |  | ||||||
|
| // ----- |  | ||||||
|
| func @dynamic_memref_cast_incompatible_result_type(%in: memref<?xf32>) { |  | ||||||
|   // expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}} |  | ||||||
|   %size = constant 10 : index |  | ||||||
|   %step = constant 1 : index |  | ||||||
|   %out = lmhlo.dynamic_memref_cast %in(%size)[%step] |  | ||||||
|            : memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]> |  | ||||||
|   return |  | ||||||
| } |  | ||||||
| // ----- |  | ||||||
|
| // CHECK-LABEL: func @reshape_memref_cast( |  | ||||||
| func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>, |  | ||||||
|          %shape2: memref<2xi32>, %shape3: memref<?xi32>) { |  | ||||||
|   // CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>, |  | ||||||
|   // CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref<?xi32> |  | ||||||
|
|   // CHECK-NEXT: [[DYN_VEC:%.*]] = lmhlo.reshape_memref_cast [[UNRANKED]] |  | ||||||
|   // CHECK-SAME:     : (memref<*xf32>, memref<1xi32>) -> memref<?xf32> |  | ||||||
|   %dyn_vec = lmhlo.reshape_memref_cast %unranked(%shape1) |  | ||||||
|                : (memref<*xf32>, memref<1xi32>) -> memref<?xf32> |  | ||||||
|
|   // CHECK-NEXT: [[DYN_MAT:%.*]] = lmhlo.reshape_memref_cast [[DYN_VEC]] |  | ||||||
|   // CHECK-SAME:     : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32> |  | ||||||
|   %dyn_mat = lmhlo.reshape_memref_cast %dyn_vec(%shape2) |  | ||||||
|                : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32> |  | ||||||
|
|   // CHECK-NEXT: {{%.*}} = lmhlo.reshape_memref_cast [[DYN_MAT]] |  | ||||||
|   // CHECK-SAME:     : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32> |  | ||||||
|   %new_unranked = lmhlo.reshape_memref_cast %dyn_mat(%shape3) |  | ||||||
|                : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32> |  | ||||||
|   return |  | ||||||
| } |  | ||||||
|
| // ----- |  | ||||||
|
| func @reshape_memref_cast_element_type_mismatch( |  | ||||||
|        %buf: memref<*xf32>, %shape: memref<1xi32>) { |  | ||||||
|   // expected-error @+1 {{element types of source and destination memref types should be the same}} |  | ||||||
|   lmhlo.reshape_memref_cast %buf(%shape) |  | ||||||
|     : (memref<*xf32>, memref<1xi32>) -> memref<?xi32> |  | ||||||
| } |  | ||||||
|
| // ----- |  | ||||||
|
| func @reshape_memref_cast_dst_ranked_shape_unranked( |  | ||||||
|        %buf: memref<*xf32>, %shape: memref<?xi32>) { |  | ||||||
|   // expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}} |  | ||||||
|   lmhlo.reshape_memref_cast %buf(%shape) |  | ||||||
|     : (memref<*xf32>, memref<?xi32>) -> memref<?xf32> |  | ||||||
|   return |  | ||||||
| } |  | ||||||
|
|
 |  | ||||||
| func @reshape_memref_cast_dst_shape_rank_mismatch( |  | ||||||
|        %buf: memref<*xf32>, %shape: memref<1xi32>) { |  | ||||||
|   // expected-error @+1 {{length of shape operand differs from the result's memref rank}} |  | ||||||
|   lmhlo.reshape_memref_cast %buf(%shape) |  | ||||||
|     : (memref<*xf32>, memref<1xi32>) -> memref<?x?xf32> |  | ||||||
|   return |  | ||||||
| } |  | ||||||
|
|
 |  | ||||||
| func @reshape_memref_cast_affine_map_is_not_identity( |  | ||||||
|         %buf: memref<4x4xf32, offset: 0, strides: [3, 2]>, |  | ||||||
|         %shape: memref<1xi32>) { |  | ||||||
|   // expected-error @+1 {{operand memref type should have identity affine map}} |  | ||||||
|   lmhlo.reshape_memref_cast %buf(%shape) |  | ||||||
|     : (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>) |  | ||||||
|     -> memref<8xf32> |  | ||||||
|   return |  | ||||||
| } |  | ||||||
|
| // ----- |  | ||||||
|
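The deleted verifier tests above are superseded by the verifiers on the std ops themselves. For comparison, the valid round-trip from the deleted @reshape_memref_cast case reads as follows with `memref_reshape` (function name is illustrative):

```mlir
// Sketch: the deleted unranked -> ranked -> unranked round trip, rewritten
// against the std memref_reshape op.
func @reshape_via_std(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
                      %shape2: memref<2xi32>, %shape3: memref<?xi32>) {
  %dyn_vec = memref_reshape %unranked(%shape1)
      : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
  %dyn_mat = memref_reshape %dyn_vec(%shape2)
      : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
  %new_unranked = memref_reshape %dyn_mat(%shape3)
      : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
  return
}
```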
| // CHECK-LABEL: func @atan2_memrefs | // CHECK-LABEL: func @atan2_memrefs | ||||||
| func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { | func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { | ||||||
|   "lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () |   "lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () | ||||||
|  |  | ||||||