Implement lowering of lmhlo.reshape_memref_cast to LLVM for unknown length shape operand.

PiperOrigin-RevId: 320959625
This commit is contained in:
Stephan Herhut 2020-07-13 15:26:53 +00:00 committed by Mehdi Amini
parent d166b66cba
commit 86f290896d
1 changed files with 128 additions and 17 deletions

View File

@ -133,8 +133,8 @@ struct ReshapeMemRefCastOpConverter
Location loc = op->getLoc();
auto reshape_op = cast<ReshapeMemRefCastOp>(op);
Type dst_type = reshape_op.getResult().getType();
auto element_type = dst_type.cast<ShapedType>().getElementType();
auto dst_type = reshape_op.getResult().getType().cast<BaseMemRefType>();
auto element_type = dst_type.getElementType();
auto shape = reshape_op.shape();
@ -162,18 +162,17 @@ struct ReshapeMemRefCastOpConverter
desc.setAlignedPtr(rewriter, loc, ptrs_n_offset.aligned_ptr);
desc.setOffset(rewriter, loc, ptrs_n_offset.offset);
auto llvmIndexTy = typeConverter.convertType(rewriter.getIndexType())
.cast<LLVM::LLVMType>();
auto llvmIndexTyPtr = llvmIndexTy.getPointerTo();
auto llvm_index_type = typeConverter.getIndexType();
auto llvm_index_ptr_type = llvm_index_type.getPointerTo();
Value stride_carried = rewriter.create<LLVM::ConstantOp>(
loc, llvmIndexTy,
loc, llvm_index_type,
rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
for (int i = shape_length - 1; i >= 0; --i) {
Value pos = rewriter.create<LLVM::ConstantOp>(
loc, llvmIndexTy,
loc, llvm_index_type,
rewriter.getIntegerAttr(rewriter.getIndexType(), i));
Value ptr = rewriter.create<LLVM::GEPOp>(
loc, llvmIndexTyPtr, shape_desc.alignedPtr(rewriter, loc),
loc, llvm_index_ptr_type, shape_desc.alignedPtr(rewriter, loc),
ValueRange{pos});
Value extracted_size = rewriter.create<LLVM::LoadOp>(loc, ptr);
desc.setSize(rewriter, loc, i, extracted_size);
@ -188,7 +187,7 @@ struct ReshapeMemRefCastOpConverter
rewriter.replaceOp(op, {desc});
} else {
Value rank = rewriter.create<LLVM::ConstantOp>(
loc, llvmIndexTy,
loc, llvm_index_type,
rewriter.getIntegerAttr(rewriter.getIndexType(), shape_length));
Value alloca =
typeConverter.promoteOneMemRefDescriptor(loc, desc, rewriter);
@ -199,15 +198,127 @@ struct ReshapeMemRefCastOpConverter
{rank, void_ptr});
rewriter.replaceOp(op, {unranked_desc});
}
} else {
/*
* TODO(pifon, herhut):
* Compute strides with llvm.loop;
* Use UnrankedMemrefDescr::ComputeSize with Alloca;
* Set all the fields using getelementptr.
*/
return failure();
return success();
}
// The shape is a rank-1 tensor with unknown length.
Value result_rank = shape_desc.size(rewriter, loc, 0);
// TODO(herhut): Propely handle address spaces.
unsigned address_space = 0;
auto target_type =
typeConverter
.convertType(UnrankedMemRefType::get(element_type, address_space))
.cast<LLVM::LLVMType>();
// Create the unranked memref descriptor that holds the ranked one. The
// inner descriptor is allocated on stack.
UnrankedMemRefDescriptor target_desc =
UnrankedMemRefDescriptor::undef(rewriter, loc, target_type);
target_desc.setRank(rewriter, loc, result_rank);
SmallVector<Value, 1> sizes;
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter,
{target_desc}, sizes);
auto void_ptr_type =
LLVM::LLVMType::getInt8PtrTy(typeConverter.getDialect());
Value ranked_desc_mem = rewriter.create<LLVM::AllocaOp>(
loc, void_ptr_type, sizes.front(), llvm::None);
target_desc.setMemRefDescPtr(rewriter, loc, ranked_desc_mem);
// Fill the fixed parts. For this, we cast to a 0-D memref.
auto zero_d_memref_type = MemRefType::get({}, element_type);
Value as_zero_d = rewriter.create<LLVM::BitcastOp>(
loc,
typeConverter.convertType(zero_d_memref_type)
.cast<LLVM::LLVMType>()
.getPointerTo(address_space),
ranked_desc_mem);
// Some common constants. Use 32 bit where required by gep struct indexes.
auto int32_type = typeConverter.convertType(rewriter.getI32Type());
Value zero_index = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter.getIndexType(), rewriter.getIndexAttr(0));
Value zero = rewriter.create<LLVM::ConstantOp>(
loc, int32_type, rewriter.getI32IntegerAttr(0));
Value one = rewriter.create<LLVM::ConstantOp>(
loc, int32_type, rewriter.getI32IntegerAttr(1));
Value two = rewriter.create<LLVM::ConstantOp>(
loc, int32_type, rewriter.getI32IntegerAttr(2));
// Set base_pointer and aligned pointer.
auto element_ptr_ptr_type = typeConverter.convertType(element_type)
.cast<LLVM::LLVMType>()
.getPointerTo(address_space)
.getPointerTo(address_space);
auto base_gep = rewriter.create<LLVM::GEPOp>(
loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, zero}));
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.allocated_ptr, base_gep);
auto aligned_gep = rewriter.create<LLVM::GEPOp>(
loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, one}));
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.aligned_ptr, aligned_gep);
// Set offset.
auto index_ptr_type =
typeConverter.getIndexType().getPointerTo(address_space);
auto offset_gep = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, as_zero_d, ValueRange({zero_index, two}));
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.offset, offset_gep);
// Use the offset pointer as base for further addressing. Copy over the
// new shape and compute strides. For this, we need to create a loop from
// rank - 1 to 0.
Value one_index = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter.getIndexType(), rewriter.getIndexAttr(1));
auto target_shape_base = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, offset_gep, ValueRange({one}));
auto target_strides_base = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, target_shape_base, ValueRange({result_rank}));
auto shape_ptr = shape_desc.alignedPtr(rewriter, loc);
auto result_rank_minus_one =
rewriter.create<LLVM::SubOp>(loc, result_rank, one_index);
Block *init_block = rewriter.getInsertionBlock();
Block *cond_block =
rewriter.splitBlock(init_block, rewriter.getInsertionPoint());
rewriter.setInsertionPointToEnd(init_block);
rewriter.create<LLVM::BrOp>(
loc, ValueRange({result_rank_minus_one, one_index}), cond_block);
rewriter.setInsertionPointToStart(cond_block);
auto index_arg = cond_block->addArgument(typeConverter.getIndexType());
auto stride_arg = cond_block->addArgument(typeConverter.getIndexType());
auto pred = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::LLVMType::getInt1Ty(typeConverter.getDialect()),
LLVM::ICmpPredicate::sge, index_arg, zero_index);
Block *body_block =
rewriter.splitBlock(cond_block, rewriter.getInsertionPoint());
rewriter.setInsertionPointToStart(body_block);
// Copy size from shape to descriptor.
auto size_load_gep = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, shape_ptr, ValueRange{index_arg});
auto extracted_size = rewriter.create<LLVM::LoadOp>(loc, size_load_gep);
auto size_store_gep = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, target_shape_base, ValueRange({index_arg}));
rewriter.create<LLVM::StoreOp>(loc, extracted_size, size_store_gep);
// Write stride value and compute next one.
auto stride_store_gep = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, target_strides_base, ValueRange({index_arg}));
rewriter.create<LLVM::StoreOp>(loc, stride_arg, stride_store_gep);
auto next_stride =
rewriter.create<LLVM::MulOp>(loc, stride_arg, extracted_size);
// Decrement loop counter and branch back.
auto decrement = rewriter.create<LLVM::SubOp>(loc, index_arg, one_index);
rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, next_stride}),
cond_block);
Block *remainder =
rewriter.splitBlock(body_block, rewriter.getInsertionPoint());
// Hook up the cond exit to the remainder.
rewriter.setInsertionPointToEnd(cond_block);
rewriter.create<LLVM::CondBrOp>(loc, pred, body_block, ValueRange(),
remainder, ValueRange());
// Reset position to beginning of new remainder block.
rewriter.setInsertionPointToStart(remainder);
rewriter.replaceOp(op, {target_desc});
return success();
}