diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc index 7d83589..7f18b8b 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc @@ -123,13 +123,138 @@ struct DynamicMemRefCastOpConverter } }; +struct ReshapeMemRefCastOpConverter + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult matchAndRewrite( + Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + + auto reshape_op = cast(op); + Type dst_type = reshape_op.getResult().getType(); + auto element_type = dst_type.cast().getElementType(); + + auto shape = reshape_op.shape(); + + ReshapeMemRefCastOp::Adaptor operands_adaptor(operands); + PtrsAndOffset ptrs_n_offset = ExtractMemRefPtrsAndOffset( + loc, reshape_op.operand(), operands_adaptor.operand(), &rewriter); + + MemRefDescriptor shape_desc(operands_adaptor.shape()); + + auto shape_memref_type = shape.getType().cast(); + + if (shape_memref_type.hasStaticShape()) { + auto shape_length = shape_memref_type.getDimSize(0); + + MemRefType targetMemRefType = MemRefType::get( + SmallVector(shape_length, 1), element_type); + auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) + .dyn_cast_or_null(); + if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) + return failure(); + // Create descriptor. + auto desc = + MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); + desc.setAllocatedPtr(rewriter, loc, ptrs_n_offset.allocated_ptr); + desc.setAlignedPtr(rewriter, loc, ptrs_n_offset.aligned_ptr); + desc.setOffset(rewriter, loc, ptrs_n_offset.offset); + + auto llvmIndexTy = typeConverter.convertType(rewriter.getIndexType()) + .cast(); + auto llvmIndexTyPtr = llvmIndexTy.getPointerTo(); + Value stride_carried = rewriter.create( + loc, llvmIndexTy, + rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + for (int i = shape_length - 1; i >= 0; --i) { + Value pos = rewriter.create( + loc, llvmIndexTy, + rewriter.getIntegerAttr(rewriter.getIndexType(), i)); + Value ptr = rewriter.create( + loc, llvmIndexTyPtr, shape_desc.alignedPtr(rewriter, loc), + ValueRange{pos}); + Value extracted_size = rewriter.create(loc, ptr); + desc.setSize(rewriter, loc, i, extracted_size); + desc.setStride(rewriter, loc, i, stride_carried); + // Update stride + if (i > 0) { + stride_carried = + rewriter.create(loc, stride_carried, extracted_size); + } + } + if (dst_type.isa()) { + rewriter.replaceOp(op, {desc}); + } else { + Value rank = rewriter.create( + loc, llvmIndexTy, + rewriter.getIntegerAttr(rewriter.getIndexType(), shape_length)); + Value alloca = + typeConverter.promoteOneMemRefDescriptor(loc, desc, rewriter); + Value void_ptr = + rewriter.create(loc, getVoidPtrType(), alloca); + auto unranked_desc = UnrankedMemRefDescriptor::pack( + rewriter, loc, typeConverter, dst_type.cast(), + {rank, void_ptr}); + rewriter.replaceOp(op, {unranked_desc}); + } + } else { + /* + * TODO(pifon, herhut): + * Compute strides with llvm.loop; + * Use UnrankedMemrefDescr::ComputeSize with Alloca; + * Set all the fields using getelementptr. + */ + return failure(); + } + return success(); + } + + private: + struct PtrsAndOffset { + Value allocated_ptr; + Value aligned_ptr; + Value offset; + }; + + PtrsAndOffset ExtractMemRefPtrsAndOffset( + Location loc, Value originalOperand, Value convertedOperand, + ConversionPatternRewriter *rewriter) const { + Type operandType = originalOperand.getType(); + Value descriptor_ptr; + if (operandType.isa()) { + descriptor_ptr = convertedOperand; + } else { + UnrankedMemRefDescriptor unranked_descriptor(convertedOperand); + Value underlying_desc_ptr = + unranked_descriptor.memRefDescPtr(*rewriter, loc); + + Type element_type = + operandType.cast().getElementType(); + LLVM::LLVMType memref_type_0d = + typeConverter.convertType(MemRefType::get(/*shape=*/{}, element_type)) + .cast(); + descriptor_ptr = rewriter->create( + loc, memref_type_0d.getPointerTo(), underlying_desc_ptr); + descriptor_ptr = rewriter->create(loc, descriptor_ptr); + } + MemRefDescriptor descriptor(descriptor_ptr); + PtrsAndOffset result; + result.allocated_ptr = descriptor.allocatedPtr(*rewriter, loc); + result.aligned_ptr = descriptor.alignedPtr(*rewriter, loc); + result.offset = descriptor.offset(*rewriter, loc); + return result; + } +}; + } // namespace void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options, LLVMTypeConverter *converter, OwningRewritePatternList *patterns) { - patterns->insert( - *converter, options); + patterns->insert(*converter, options); } } // namespace lmhlo