Implement lowering of lmhlo.reshape_memref_cast to LLVM for unknown length shape operand.
PiperOrigin-RevId: 320959625
This commit is contained in:
parent
d166b66cba
commit
86f290896d
|
@ -133,8 +133,8 @@ struct ReshapeMemRefCastOpConverter
|
|||
Location loc = op->getLoc();
|
||||
|
||||
auto reshape_op = cast<ReshapeMemRefCastOp>(op);
|
||||
Type dst_type = reshape_op.getResult().getType();
|
||||
auto element_type = dst_type.cast<ShapedType>().getElementType();
|
||||
auto dst_type = reshape_op.getResult().getType().cast<BaseMemRefType>();
|
||||
auto element_type = dst_type.getElementType();
|
||||
|
||||
auto shape = reshape_op.shape();
|
||||
|
||||
|
@ -162,18 +162,17 @@ struct ReshapeMemRefCastOpConverter
|
|||
desc.setAlignedPtr(rewriter, loc, ptrs_n_offset.aligned_ptr);
|
||||
desc.setOffset(rewriter, loc, ptrs_n_offset.offset);
|
||||
|
||||
auto llvmIndexTy = typeConverter.convertType(rewriter.getIndexType())
|
||||
.cast<LLVM::LLVMType>();
|
||||
auto llvmIndexTyPtr = llvmIndexTy.getPointerTo();
|
||||
auto llvm_index_type = typeConverter.getIndexType();
|
||||
auto llvm_index_ptr_type = llvm_index_type.getPointerTo();
|
||||
Value stride_carried = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmIndexTy,
|
||||
loc, llvm_index_type,
|
||||
rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
|
||||
for (int i = shape_length - 1; i >= 0; --i) {
|
||||
Value pos = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmIndexTy,
|
||||
loc, llvm_index_type,
|
||||
rewriter.getIntegerAttr(rewriter.getIndexType(), i));
|
||||
Value ptr = rewriter.create<LLVM::GEPOp>(
|
||||
loc, llvmIndexTyPtr, shape_desc.alignedPtr(rewriter, loc),
|
||||
loc, llvm_index_ptr_type, shape_desc.alignedPtr(rewriter, loc),
|
||||
ValueRange{pos});
|
||||
Value extracted_size = rewriter.create<LLVM::LoadOp>(loc, ptr);
|
||||
desc.setSize(rewriter, loc, i, extracted_size);
|
||||
|
@ -188,7 +187,7 @@ struct ReshapeMemRefCastOpConverter
|
|||
rewriter.replaceOp(op, {desc});
|
||||
} else {
|
||||
Value rank = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, llvmIndexTy,
|
||||
loc, llvm_index_type,
|
||||
rewriter.getIntegerAttr(rewriter.getIndexType(), shape_length));
|
||||
Value alloca =
|
||||
typeConverter.promoteOneMemRefDescriptor(loc, desc, rewriter);
|
||||
|
@ -199,15 +198,127 @@ struct ReshapeMemRefCastOpConverter
|
|||
{rank, void_ptr});
|
||||
rewriter.replaceOp(op, {unranked_desc});
|
||||
}
|
||||
} else {
|
||||
/*
|
||||
* TODO(pifon, herhut):
|
||||
* Compute strides with llvm.loop;
|
||||
* Use UnrankedMemrefDescr::ComputeSize with Alloca;
|
||||
* Set all the fields using getelementptr.
|
||||
*/
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
// The shape is a rank-1 tensor with unknown length.
|
||||
Value result_rank = shape_desc.size(rewriter, loc, 0);
|
||||
// TODO(herhut): Properly handle address spaces.
|
||||
unsigned address_space = 0;
|
||||
auto target_type =
|
||||
typeConverter
|
||||
.convertType(UnrankedMemRefType::get(element_type, address_space))
|
||||
.cast<LLVM::LLVMType>();
|
||||
// Create the unranked memref descriptor that holds the ranked one. The
|
||||
// inner descriptor is allocated on stack.
|
||||
UnrankedMemRefDescriptor target_desc =
|
||||
UnrankedMemRefDescriptor::undef(rewriter, loc, target_type);
|
||||
target_desc.setRank(rewriter, loc, result_rank);
|
||||
SmallVector<Value, 1> sizes;
|
||||
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter,
|
||||
{target_desc}, sizes);
|
||||
auto void_ptr_type =
|
||||
LLVM::LLVMType::getInt8PtrTy(typeConverter.getDialect());
|
||||
Value ranked_desc_mem = rewriter.create<LLVM::AllocaOp>(
|
||||
loc, void_ptr_type, sizes.front(), llvm::None);
|
||||
target_desc.setMemRefDescPtr(rewriter, loc, ranked_desc_mem);
|
||||
|
||||
// Fill the fixed parts. For this, we cast to a 0-D memref.
|
||||
auto zero_d_memref_type = MemRefType::get({}, element_type);
|
||||
Value as_zero_d = rewriter.create<LLVM::BitcastOp>(
|
||||
loc,
|
||||
typeConverter.convertType(zero_d_memref_type)
|
||||
.cast<LLVM::LLVMType>()
|
||||
.getPointerTo(address_space),
|
||||
ranked_desc_mem);
|
||||
// Some common constants. Use 32 bit where required by gep struct indexes.
|
||||
auto int32_type = typeConverter.convertType(rewriter.getI32Type());
|
||||
Value zero_index = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, typeConverter.getIndexType(), rewriter.getIndexAttr(0));
|
||||
Value zero = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32_type, rewriter.getI32IntegerAttr(0));
|
||||
Value one = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32_type, rewriter.getI32IntegerAttr(1));
|
||||
Value two = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, int32_type, rewriter.getI32IntegerAttr(2));
|
||||
// Set base_pointer and aligned pointer.
|
||||
auto element_ptr_ptr_type = typeConverter.convertType(element_type)
|
||||
.cast<LLVM::LLVMType>()
|
||||
.getPointerTo(address_space)
|
||||
.getPointerTo(address_space);
|
||||
auto base_gep = rewriter.create<LLVM::GEPOp>(
|
||||
loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, zero}));
|
||||
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.allocated_ptr, base_gep);
|
||||
auto aligned_gep = rewriter.create<LLVM::GEPOp>(
|
||||
loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, one}));
|
||||
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.aligned_ptr, aligned_gep);
|
||||
// Set offset.
|
||||
auto index_ptr_type =
|
||||
typeConverter.getIndexType().getPointerTo(address_space);
|
||||
auto offset_gep = rewriter.create<LLVM::GEPOp>(
|
||||
loc, index_ptr_type, as_zero_d, ValueRange({zero_index, two}));
|
||||
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.offset, offset_gep);
|
||||
|
||||
// Use the offset pointer as base for further addressing. Copy over the
|
||||
// new shape and compute strides. For this, we need to create a loop from
|
||||
// rank - 1 to 0.
|
||||
Value one_index = rewriter.create<LLVM::ConstantOp>(
|
||||
loc, typeConverter.getIndexType(), rewriter.getIndexAttr(1));
|
||||
auto target_shape_base = rewriter.create<LLVM::GEPOp>(
|
||||
loc, index_ptr_type, offset_gep, ValueRange({one}));
|
||||
auto target_strides_base = rewriter.create<LLVM::GEPOp>(
|
||||
loc, index_ptr_type, target_shape_base, ValueRange({result_rank}));
|
||||
auto shape_ptr = shape_desc.alignedPtr(rewriter, loc);
|
||||
auto result_rank_minus_one =
|
||||
rewriter.create<LLVM::SubOp>(loc, result_rank, one_index);
|
||||
|
||||
Block *init_block = rewriter.getInsertionBlock();
|
||||
Block *cond_block =
|
||||
rewriter.splitBlock(init_block, rewriter.getInsertionPoint());
|
||||
rewriter.setInsertionPointToEnd(init_block);
|
||||
rewriter.create<LLVM::BrOp>(
|
||||
loc, ValueRange({result_rank_minus_one, one_index}), cond_block);
|
||||
rewriter.setInsertionPointToStart(cond_block);
|
||||
auto index_arg = cond_block->addArgument(typeConverter.getIndexType());
|
||||
auto stride_arg = cond_block->addArgument(typeConverter.getIndexType());
|
||||
auto pred = rewriter.create<LLVM::ICmpOp>(
|
||||
loc, LLVM::LLVMType::getInt1Ty(typeConverter.getDialect()),
|
||||
LLVM::ICmpPredicate::sge, index_arg, zero_index);
|
||||
|
||||
Block *body_block =
|
||||
rewriter.splitBlock(cond_block, rewriter.getInsertionPoint());
|
||||
rewriter.setInsertionPointToStart(body_block);
|
||||
|
||||
// Copy size from shape to descriptor.
|
||||
auto size_load_gep = rewriter.create<LLVM::GEPOp>(
|
||||
loc, index_ptr_type, shape_ptr, ValueRange{index_arg});
|
||||
auto extracted_size = rewriter.create<LLVM::LoadOp>(loc, size_load_gep);
|
||||
auto size_store_gep = rewriter.create<LLVM::GEPOp>(
|
||||
loc, index_ptr_type, target_shape_base, ValueRange({index_arg}));
|
||||
rewriter.create<LLVM::StoreOp>(loc, extracted_size, size_store_gep);
|
||||
// Write stride value and compute next one.
|
||||
auto stride_store_gep = rewriter.create<LLVM::GEPOp>(
|
||||
loc, index_ptr_type, target_strides_base, ValueRange({index_arg}));
|
||||
rewriter.create<LLVM::StoreOp>(loc, stride_arg, stride_store_gep);
|
||||
auto next_stride =
|
||||
rewriter.create<LLVM::MulOp>(loc, stride_arg, extracted_size);
|
||||
|
||||
// Decrement loop counter and branch back.
|
||||
auto decrement = rewriter.create<LLVM::SubOp>(loc, index_arg, one_index);
|
||||
rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, next_stride}),
|
||||
cond_block);
|
||||
|
||||
Block *remainder =
|
||||
rewriter.splitBlock(body_block, rewriter.getInsertionPoint());
|
||||
|
||||
// Hook up the cond exit to the remainder.
|
||||
rewriter.setInsertionPointToEnd(cond_block);
|
||||
rewriter.create<LLVM::CondBrOp>(loc, pred, body_block, ValueRange(),
|
||||
remainder, ValueRange());
|
||||
|
||||
// Reset position to beginning of new remainder block.
|
||||
rewriter.setInsertionPointToStart(remainder);
|
||||
rewriter.replaceOp(op, {target_desc});
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue