[HLO] Clean-up dynamic allocation in hlo-legalize-to-lhlo pass.

PiperOrigin-RevId: 335385243
This commit is contained in:
Alexander Belyaev 2020-10-05 03:54:51 -07:00 committed by TensorFlow MLIR Team
parent 7367eac074
commit d927e32451
1 changed files with 13 additions and 18 deletions

View File

@ -45,7 +45,7 @@ using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
Value shape_operand,
ConversionPatternRewriter* rewriter) {
auto result_type = result.getType().dyn_cast<ShapedType>();
auto result_type = result.getType().dyn_cast<RankedTensorType>();
if (!result_type) {
result.getDefiningOp()->emitOpError()
<< "tensor to buffer conversion expects ranked results";
@ -53,17 +53,13 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
auto memref_type =
MemRefType::get(result_type.getShape(), result_type.getElementType());
Operation* op = result.getDefiningOp();
// Extract the required element out of the vector.
SmallVector<Value, 4> dynamic_operands;
for (auto shape_element : llvm::enumerate(result_type.getShape())) {
if (shape_element.value() != ShapedType::kDynamicSize) continue;
Value index = rewriter->create<ConstantOp>(
loc, rewriter->getIntegerAttr(rewriter->getIndexType(),
shape_element.index()));
Value alloc_operand = rewriter->create<ExtractElementOp>(loc, shape_operand,
ValueRange{index});
Value index = rewriter->create<ConstantIndexOp>(loc, shape_element.index());
Value alloc_operand =
rewriter->create<ExtractElementOp>(loc, shape_operand, index);
if (!alloc_operand.getType().isIndex()) {
alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
rewriter->getIndexType());
@ -71,15 +67,12 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
dynamic_operands.push_back(alloc_operand);
}
// Insert in front of op to ensure sizes are available.
OpBuilder allocBuilder(op);
auto alloc = allocBuilder.create<AllocOp>(loc, memref_type, dynamic_operands);
return alloc;
return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands);
}
Value InsertAlloc(Location loc, OpResult result,
ConversionPatternRewriter* rewriter) {
auto result_type = result.getType().dyn_cast<ShapedType>();
auto result_type = result.getType().dyn_cast<RankedTensorType>();
if (!result_type || !result_type.hasStaticShape()) {
result.getDefiningOp()->emitOpError()
<< "tensor to buffer conversion expects statically shaped results";
@ -112,19 +105,21 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
buffer_args.push_back(
InsertAlloc(op->getLoc(), result.value(), &rewriter));
} else {
SmallVector<Value, 1> results_shape;
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
if (!shape_type_op) return failure();
if (failed(
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
return failure();
SmallVector<Value, 1> results_shape;
auto status =
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
if (failed(status)) return failure();
buffer_args.push_back(InsertDynamicAllocAndDealloc(
op->getLoc(), result.value(), results_shape.front(), &rewriter));
}
}
rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
buffer_args, op->getAttrs());
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
rewriter.replaceOp(
op, llvm::makeArrayRef(buffer_args).drop_front(operands.size()));
return success();
}
};