[HLO] Clean-up dynamic allocation in hlo-legalize-to-lhlo pass.

PiperOrigin-RevId: 335385243
Author: Alexander Belyaev (2020-10-05 03:54:51 -07:00), committed by TensorFlow MLIR Team
parent 7367eac074
commit d927e32451
1 changed file with 13 additions and 18 deletions


@@ -45,7 +45,7 @@ using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
 Value InsertDynamicAllocAndDealloc(Location loc, Value result,
                                    Value shape_operand,
                                    ConversionPatternRewriter* rewriter) {
-  auto result_type = result.getType().dyn_cast<ShapedType>();
+  auto result_type = result.getType().dyn_cast<RankedTensorType>();
   if (!result_type) {
     result.getDefiningOp()->emitOpError()
         << "tensor to buffer conversion expects ranked results";
@@ -53,17 +53,13 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
   auto memref_type =
       MemRefType::get(result_type.getShape(), result_type.getElementType());
 
-  Operation* op = result.getDefiningOp();
-
   // Extract the required element out of the vector.
   SmallVector<Value, 4> dynamic_operands;
   for (auto shape_element : llvm::enumerate(result_type.getShape())) {
     if (shape_element.value() != ShapedType::kDynamicSize) continue;
-    Value index = rewriter->create<ConstantOp>(
-        loc, rewriter->getIntegerAttr(rewriter->getIndexType(),
-                                      shape_element.index()));
-    Value alloc_operand = rewriter->create<ExtractElementOp>(loc, shape_operand,
-                                                             ValueRange{index});
+    Value index = rewriter->create<ConstantIndexOp>(loc, shape_element.index());
+    Value alloc_operand =
+        rewriter->create<ExtractElementOp>(loc, shape_operand, index);
     if (!alloc_operand.getType().isIndex()) {
       alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
                                                     rewriter->getIndexType());
@@ -71,15 +67,12 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result,
     dynamic_operands.push_back(alloc_operand);
   }
 
-  // Insert in front of op to ensure sizes are available.
-  OpBuilder allocBuilder(op);
-  auto alloc = allocBuilder.create<AllocOp>(loc, memref_type, dynamic_operands);
-  return alloc;
+  return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands);
 }
 
 Value InsertAlloc(Location loc, OpResult result,
                   ConversionPatternRewriter* rewriter) {
-  auto result_type = result.getType().dyn_cast<ShapedType>();
+  auto result_type = result.getType().dyn_cast<RankedTensorType>();
   if (!result_type || !result_type.hasStaticShape()) {
     result.getDefiningOp()->emitOpError()
         << "tensor to buffer conversion expects statically shaped results";
@@ -112,19 +105,21 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
         buffer_args.push_back(
             InsertAlloc(op->getLoc(), result.value(), &rewriter));
       } else {
-        SmallVector<Value, 1> results_shape;
         auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
         if (!shape_type_op) return failure();
-        if (failed(
-                shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
-          return failure();
+
+        SmallVector<Value, 1> results_shape;
+        auto status =
+            shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
+        if (failed(status)) return failure();
         buffer_args.push_back(InsertDynamicAllocAndDealloc(
             op->getLoc(), result.value(), results_shape.front(), &rewriter));
       }
     }
     rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
                                                 buffer_args, op->getAttrs());
-    rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
+    rewriter.replaceOp(
+        op, llvm::makeArrayRef(buffer_args).drop_front(operands.size()));
     return success();
   }
 };
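
For reference, a sketch of the allocation helper as it reads after this change, reconstructed from the hunks above. Only the lines visible in the diff are taken from the file; closing braces and anything outside the shown context are filled in by inference, and the surrounding includes, namespaces, and patterns are omitted.

// Sketch of InsertDynamicAllocAndDealloc after this commit (reconstructed
// from the hunks above; parts outside the shown context are inferred).
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
                                   Value shape_operand,
                                   ConversionPatternRewriter* rewriter) {
  auto result_type = result.getType().dyn_cast<RankedTensorType>();
  if (!result_type) {
    result.getDefiningOp()->emitOpError()
        << "tensor to buffer conversion expects ranked results";
  }
  auto memref_type =
      MemRefType::get(result_type.getShape(), result_type.getElementType());

  // Extract the required element out of the vector.
  SmallVector<Value, 4> dynamic_operands;
  for (auto shape_element : llvm::enumerate(result_type.getShape())) {
    if (shape_element.value() != ShapedType::kDynamicSize) continue;
    // The index constant and extract_element are now created through the
    // conversion rewriter, so the temporary OpBuilder anchored at the
    // defining op is no longer needed.
    Value index = rewriter->create<ConstantIndexOp>(loc, shape_element.index());
    Value alloc_operand =
        rewriter->create<ExtractElementOp>(loc, shape_operand, index);
    if (!alloc_operand.getType().isIndex()) {
      alloc_operand = rewriter->create<IndexCastOp>(loc, alloc_operand,
                                                    rewriter->getIndexType());
    }
    dynamic_operands.push_back(alloc_operand);
  }

  // The alloc itself also goes through the rewriter now.
  return rewriter->create<AllocOp>(loc, memref_type, dynamic_operands);
}

Creating these ops through the ConversionPatternRewriter rather than a separate OpBuilder keeps them visible to the dialect-conversion driver, which is presumably the clean-up the commit title refers to.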