[HLO] Clean-up dynamic allocation in hlo-legalize-to-lhlo pass.
PiperOrigin-RevId: 335385243
This commit is contained in:
parent
7367eac074
commit
d927e32451
|
@ -45,7 +45,7 @@ using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
|
|||
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
|
||||
Value shape_operand,
|
||||
ConversionPatternRewriter* rewriter) {
|
||||
auto result_type = result.getType().dyn_cast<ShapedType>();
|
||||
auto result_type = result.getType().dyn_cast<RankedTensorType>();
|
||||
if (!result_type) {
|
||||
result.getDefiningOp()->emitOpError()
|
||||
<< "tensor to buffer conversion expects ranked results";
|
||||
|
@ -53,17 +53,13 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
|
|||
auto memref_type =
|
||||
MemRefType::get(result_type.getShape(), result_type.getElementType());
|
||||
|
||||
Operation* op = result.getDefiningOp();
|
||||
|
||||
// Extract the required element out of the vector.
|
||||
SmallVector<Value, 4> dynamic_operands;
|
||||
for (auto shape_element : llvm::enumerate(result_type.getShape())) {
|
||||
if (shape_element.value() != ShapedType::kDynamicSize) continue;
|
||||
Value index = rewriter->create<ConstantOp>(
|
||||
loc, rewriter->getIntegerAttr(rewriter->getIndexType(),
|
||||
shape_element.index()));
|
||||
Value alloc_operand = rewriter->create<ExtractElementOp>(loc, shape_operand,
|
||||
ValueRange{index});
|
||||
Value index = rewriter->create<ConstantIndexOp>(loc, shape_element.index());
|
||||
Value alloc_operand =
|
||||
rewriter->create<ExtractElementOp>(loc, shape_operand, index);
|
||||
if (!alloc_operand.getType().isIndex()) {
|
||||
alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
|
||||
rewriter->getIndexType());
|
||||
|
@ -71,15 +67,12 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
|
|||
dynamic_operands.push_back(alloc_operand);
|
||||
}
|
||||
|
||||
// Insert in front of op to ensure sizes are available.
|
||||
OpBuilder allocBuilder(op);
|
||||
auto alloc = allocBuilder.create<AllocOp>(loc, memref_type, dynamic_operands);
|
||||
return alloc;
|
||||
return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands);
|
||||
}
|
||||
|
||||
Value InsertAlloc(Location loc, OpResult result,
|
||||
ConversionPatternRewriter* rewriter) {
|
||||
auto result_type = result.getType().dyn_cast<ShapedType>();
|
||||
auto result_type = result.getType().dyn_cast<RankedTensorType>();
|
||||
if (!result_type || !result_type.hasStaticShape()) {
|
||||
result.getDefiningOp()->emitOpError()
|
||||
<< "tensor to buffer conversion expects statically shaped results";
|
||||
|
@ -112,19 +105,21 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
|||
buffer_args.push_back(
|
||||
InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
||||
} else {
|
||||
SmallVector<Value, 1> results_shape;
|
||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||
if (!shape_type_op) return failure();
|
||||
if (failed(
|
||||
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
|
||||
return failure();
|
||||
|
||||
SmallVector<Value, 1> results_shape;
|
||||
auto status =
|
||||
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
|
||||
if (failed(status)) return failure();
|
||||
buffer_args.push_back(InsertDynamicAllocAndDealloc(
|
||||
op->getLoc(), result.value(), results_shape.front(), &rewriter));
|
||||
}
|
||||
}
|
||||
rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
|
||||
buffer_args, op->getAttrs());
|
||||
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
|
||||
rewriter.replaceOp(
|
||||
op, llvm::makeArrayRef(buffer_args).drop_front(operands.size()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue