From 86f290896daaab8e6ad3c31131040d332aafbe7b Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Mon, 13 Jul 2020 15:26:53 +0000 Subject: [PATCH] Implement lowering of lmhlo.reshape_memref_cast to LLVM for unknown length shape operand. PiperOrigin-RevId: 320959625 --- .../mhlo/transforms/lhlo_legalize_to_llvm.cc | 145 ++++++++++++++++-- 1 file changed, 128 insertions(+), 17 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc index 7f18b8b..0ed1b18 100644 --- a/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc +++ b/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc @@ -133,8 +133,8 @@ struct ReshapeMemRefCastOpConverter Location loc = op->getLoc(); auto reshape_op = cast(op); - Type dst_type = reshape_op.getResult().getType(); - auto element_type = dst_type.cast().getElementType(); + auto dst_type = reshape_op.getResult().getType().cast(); + auto element_type = dst_type.getElementType(); auto shape = reshape_op.shape(); @@ -162,18 +162,17 @@ struct ReshapeMemRefCastOpConverter desc.setAlignedPtr(rewriter, loc, ptrs_n_offset.aligned_ptr); desc.setOffset(rewriter, loc, ptrs_n_offset.offset); - auto llvmIndexTy = typeConverter.convertType(rewriter.getIndexType()) - .cast(); - auto llvmIndexTyPtr = llvmIndexTy.getPointerTo(); + auto llvm_index_type = typeConverter.getIndexType(); + auto llvm_index_ptr_type = llvm_index_type.getPointerTo(); Value stride_carried = rewriter.create( - loc, llvmIndexTy, + loc, llvm_index_type, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); for (int i = shape_length - 1; i >= 0; --i) { Value pos = rewriter.create( - loc, llvmIndexTy, + loc, llvm_index_type, rewriter.getIntegerAttr(rewriter.getIndexType(), i)); Value ptr = rewriter.create( - loc, llvmIndexTyPtr, shape_desc.alignedPtr(rewriter, loc), + loc, llvm_index_ptr_type, shape_desc.alignedPtr(rewriter, loc), ValueRange{pos}); Value extracted_size = rewriter.create(loc, ptr); 
desc.setSize(rewriter, loc, i, extracted_size); @@ -188,7 +187,7 @@ struct ReshapeMemRefCastOpConverter rewriter.replaceOp(op, {desc}); } else { Value rank = rewriter.create( - loc, llvmIndexTy, + loc, llvm_index_type, rewriter.getIntegerAttr(rewriter.getIndexType(), shape_length)); Value alloca = typeConverter.promoteOneMemRefDescriptor(loc, desc, rewriter); @@ -199,15 +198,127 @@ struct ReshapeMemRefCastOpConverter {rank, void_ptr}); rewriter.replaceOp(op, {unranked_desc}); } - } else { - /* - * TODO(pifon, herhut): - * Compute strides with llvm.loop; - * Use UnrankedMemrefDescr::ComputeSize with Alloca; - * Set all the fields using getelementptr. - */ - return failure(); + return success(); } + + // The shape is a rank-1 tensor with unknown length. + Value result_rank = shape_desc.size(rewriter, loc, 0); + // TODO(herhut): Properly handle address spaces. + unsigned address_space = 0; + auto target_type = + typeConverter + .convertType(UnrankedMemRefType::get(element_type, address_space)) + .cast(); + // Create the unranked memref descriptor that holds the ranked one. The + // inner descriptor is allocated on stack. + UnrankedMemRefDescriptor target_desc = + UnrankedMemRefDescriptor::undef(rewriter, loc, target_type); + target_desc.setRank(rewriter, loc, result_rank); + SmallVector sizes; + UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter, + {target_desc}, sizes); + auto void_ptr_type = + LLVM::LLVMType::getInt8PtrTy(typeConverter.getDialect()); + Value ranked_desc_mem = rewriter.create( + loc, void_ptr_type, sizes.front(), llvm::None); + target_desc.setMemRefDescPtr(rewriter, loc, ranked_desc_mem); + + // Fill the fixed parts. For this, we cast to a 0-D memref. + auto zero_d_memref_type = MemRefType::get({}, element_type); + Value as_zero_d = rewriter.create( + loc, + typeConverter.convertType(zero_d_memref_type) + .cast() + .getPointerTo(address_space), + ranked_desc_mem); + // Some common constants. 
Use 32 bit where required by gep struct indexes. + auto int32_type = typeConverter.convertType(rewriter.getI32Type()); + Value zero_index = rewriter.create( + loc, typeConverter.getIndexType(), rewriter.getIndexAttr(0)); + Value zero = rewriter.create( + loc, int32_type, rewriter.getI32IntegerAttr(0)); + Value one = rewriter.create( + loc, int32_type, rewriter.getI32IntegerAttr(1)); + Value two = rewriter.create( + loc, int32_type, rewriter.getI32IntegerAttr(2)); + // Set base_pointer and aligned pointer. + auto element_ptr_ptr_type = typeConverter.convertType(element_type) + .cast() + .getPointerTo(address_space) + .getPointerTo(address_space); + auto base_gep = rewriter.create( + loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, zero})); + rewriter.create(loc, ptrs_n_offset.allocated_ptr, base_gep); + auto aligned_gep = rewriter.create( + loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, one})); + rewriter.create(loc, ptrs_n_offset.aligned_ptr, aligned_gep); + // Set offset. + auto index_ptr_type = + typeConverter.getIndexType().getPointerTo(address_space); + auto offset_gep = rewriter.create( + loc, index_ptr_type, as_zero_d, ValueRange({zero_index, two})); + rewriter.create(loc, ptrs_n_offset.offset, offset_gep); + + // Use the offset pointer as base for further addressing. Copy over the + // new shape and compute strides. For this, we need to create a loop from + // rank - 1 to 0. 
+ Value one_index = rewriter.create( + loc, typeConverter.getIndexType(), rewriter.getIndexAttr(1)); + auto target_shape_base = rewriter.create( + loc, index_ptr_type, offset_gep, ValueRange({one})); + auto target_strides_base = rewriter.create( + loc, index_ptr_type, target_shape_base, ValueRange({result_rank})); + auto shape_ptr = shape_desc.alignedPtr(rewriter, loc); + auto result_rank_minus_one = + rewriter.create(loc, result_rank, one_index); + + Block *init_block = rewriter.getInsertionBlock(); + Block *cond_block = + rewriter.splitBlock(init_block, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToEnd(init_block); + rewriter.create( + loc, ValueRange({result_rank_minus_one, one_index}), cond_block); + rewriter.setInsertionPointToStart(cond_block); + auto index_arg = cond_block->addArgument(typeConverter.getIndexType()); + auto stride_arg = cond_block->addArgument(typeConverter.getIndexType()); + auto pred = rewriter.create( + loc, LLVM::LLVMType::getInt1Ty(typeConverter.getDialect()), + LLVM::ICmpPredicate::sge, index_arg, zero_index); + + Block *body_block = + rewriter.splitBlock(cond_block, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToStart(body_block); + + // Copy size from shape to descriptor. + auto size_load_gep = rewriter.create( + loc, index_ptr_type, shape_ptr, ValueRange{index_arg}); + auto extracted_size = rewriter.create(loc, size_load_gep); + auto size_store_gep = rewriter.create( + loc, index_ptr_type, target_shape_base, ValueRange({index_arg})); + rewriter.create(loc, extracted_size, size_store_gep); + // Write stride value and compute next one. + auto stride_store_gep = rewriter.create( + loc, index_ptr_type, target_strides_base, ValueRange({index_arg})); + rewriter.create(loc, stride_arg, stride_store_gep); + auto next_stride = + rewriter.create(loc, stride_arg, extracted_size); + + // Decrement loop counter and branch back. 
+ auto decrement = rewriter.create(loc, index_arg, one_index); + rewriter.create(loc, ValueRange({decrement, next_stride}), + cond_block); + + Block *remainder = + rewriter.splitBlock(body_block, rewriter.getInsertionPoint()); + + // Hook up the cond exit to the remainder. + rewriter.setInsertionPointToEnd(cond_block); + rewriter.create(loc, pred, body_block, ValueRange(), + remainder, ValueRange()); + + // Reset position to beginning of new remainder block. + rewriter.setInsertionPointToStart(remainder); + rewriter.replaceOp(op, {target_desc}); return success(); }