[HLO] Clean-up dynamic allocation in hlo-legalize-to-lhlo pass.
PiperOrigin-RevId: 335385243
This commit is contained in:
parent
7367eac074
commit
d927e32451
|
@ -45,7 +45,7 @@ using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
|
||||||
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
|
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
|
||||||
Value shape_operand,
|
Value shape_operand,
|
||||||
ConversionPatternRewriter* rewriter) {
|
ConversionPatternRewriter* rewriter) {
|
||||||
auto result_type = result.getType().dyn_cast<ShapedType>();
|
auto result_type = result.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!result_type) {
|
if (!result_type) {
|
||||||
result.getDefiningOp()->emitOpError()
|
result.getDefiningOp()->emitOpError()
|
||||||
<< "tensor to buffer conversion expects ranked results";
|
<< "tensor to buffer conversion expects ranked results";
|
||||||
|
@ -53,17 +53,13 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
|
||||||
auto memref_type =
|
auto memref_type =
|
||||||
MemRefType::get(result_type.getShape(), result_type.getElementType());
|
MemRefType::get(result_type.getShape(), result_type.getElementType());
|
||||||
|
|
||||||
Operation* op = result.getDefiningOp();
|
|
||||||
|
|
||||||
// Extract the required element out of the vector.
|
// Extract the required element out of the vector.
|
||||||
SmallVector<Value, 4> dynamic_operands;
|
SmallVector<Value, 4> dynamic_operands;
|
||||||
for (auto shape_element : llvm::enumerate(result_type.getShape())) {
|
for (auto shape_element : llvm::enumerate(result_type.getShape())) {
|
||||||
if (shape_element.value() != ShapedType::kDynamicSize) continue;
|
if (shape_element.value() != ShapedType::kDynamicSize) continue;
|
||||||
Value index = rewriter->create<ConstantOp>(
|
Value index = rewriter->create<ConstantIndexOp>(loc, shape_element.index());
|
||||||
loc, rewriter->getIntegerAttr(rewriter->getIndexType(),
|
Value alloc_operand =
|
||||||
shape_element.index()));
|
rewriter->create<ExtractElementOp>(loc, shape_operand, index);
|
||||||
Value alloc_operand = rewriter->create<ExtractElementOp>(loc, shape_operand,
|
|
||||||
ValueRange{index});
|
|
||||||
if (!alloc_operand.getType().isIndex()) {
|
if (!alloc_operand.getType().isIndex()) {
|
||||||
alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
|
alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
|
||||||
rewriter->getIndexType());
|
rewriter->getIndexType());
|
||||||
|
@ -71,15 +67,12 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
|
||||||
dynamic_operands.push_back(alloc_operand);
|
dynamic_operands.push_back(alloc_operand);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert in front of op to ensure sizes are available.
|
return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands);
|
||||||
OpBuilder allocBuilder(op);
|
|
||||||
auto alloc = allocBuilder.create<AllocOp>(loc, memref_type, dynamic_operands);
|
|
||||||
return alloc;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Value InsertAlloc(Location loc, OpResult result,
|
Value InsertAlloc(Location loc, OpResult result,
|
||||||
ConversionPatternRewriter* rewriter) {
|
ConversionPatternRewriter* rewriter) {
|
||||||
auto result_type = result.getType().dyn_cast<ShapedType>();
|
auto result_type = result.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!result_type || !result_type.hasStaticShape()) {
|
if (!result_type || !result_type.hasStaticShape()) {
|
||||||
result.getDefiningOp()->emitOpError()
|
result.getDefiningOp()->emitOpError()
|
||||||
<< "tensor to buffer conversion expects statically shaped results";
|
<< "tensor to buffer conversion expects statically shaped results";
|
||||||
|
@ -112,19 +105,21 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
|
||||||
buffer_args.push_back(
|
buffer_args.push_back(
|
||||||
InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
InsertAlloc(op->getLoc(), result.value(), &rewriter));
|
||||||
} else {
|
} else {
|
||||||
SmallVector<Value, 1> results_shape;
|
|
||||||
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
|
||||||
if (!shape_type_op) return failure();
|
if (!shape_type_op) return failure();
|
||||||
if (failed(
|
|
||||||
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
|
SmallVector<Value, 1> results_shape;
|
||||||
return failure();
|
auto status =
|
||||||
|
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
|
||||||
|
if (failed(status)) return failure();
|
||||||
buffer_args.push_back(InsertDynamicAllocAndDealloc(
|
buffer_args.push_back(InsertDynamicAllocAndDealloc(
|
||||||
op->getLoc(), result.value(), results_shape.front(), &rewriter));
|
op->getLoc(), result.value(), results_shape.front(), &rewriter));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
|
rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
|
||||||
buffer_args, op->getAttrs());
|
buffer_args, op->getAttrs());
|
||||||
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
|
rewriter.replaceOp(
|
||||||
|
op, llvm::makeArrayRef(buffer_args).drop_front(operands.size()));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue