From d927e324519bff6770075137b179ca63861201dd Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Mon, 5 Oct 2020 03:54:51 -0700 Subject: [PATCH] [HLO] Clean-up dynamic allocation in hlo-legalize-to-lhlo pass. PiperOrigin-RevId: 335385243 --- .../mhlo/transforms/hlo_legalize_to_lhlo.cc | 31 ++++++++----------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index a808608..3485aff 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -45,7 +45,7 @@ using BaseOpConversion = BufferAssignmentOpConversionPattern; Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value shape_operand, ConversionPatternRewriter* rewriter) { - auto result_type = result.getType().dyn_cast(); + auto result_type = result.getType().dyn_cast(); if (!result_type) { result.getDefiningOp()->emitOpError() << "tensor to buffer conversion expects ranked results"; @@ -53,17 +53,13 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result, auto memref_type = MemRefType::get(result_type.getShape(), result_type.getElementType()); - Operation* op = result.getDefiningOp(); - // Extract the required element out of the vector. SmallVector dynamic_operands; for (auto shape_element : llvm::enumerate(result_type.getShape())) { if (shape_element.value() != ShapedType::kDynamicSize) continue; - Value index = rewriter->create( - loc, rewriter->getIntegerAttr(rewriter->getIndexType(), - shape_element.index())); - Value alloc_operand = rewriter->create(loc, shape_operand, - ValueRange{index}); + Value index = rewriter->create(loc, shape_element.index()); + Value alloc_operand = + rewriter->create(loc, shape_operand, index); if (!alloc_operand.getType().isIndex()) { alloc_operand = rewriter->create(loc, alloc_operand, rewriter->getIndexType()); @@ -71,15 +67,12 @@ Value InsertDynamicAllocAndDealloc(Location loc, Value result, dynamic_operands.push_back(alloc_operand); } - // Insert in front of op to ensure sizes are available. - OpBuilder allocBuilder(op); - auto alloc = allocBuilder.create(loc, memref_type, dynamic_operands); - return alloc; + return rewriter->create(loc, memref_type, dynamic_operands); } Value InsertAlloc(Location loc, OpResult result, ConversionPatternRewriter* rewriter) { - auto result_type = result.getType().dyn_cast(); + auto result_type = result.getType().dyn_cast(); if (!result_type || !result_type.hasStaticShape()) { result.getDefiningOp()->emitOpError() << "tensor to buffer conversion expects statically shaped results"; @@ -112,19 +105,21 @@ class HloToLhloOpConverter : public BaseOpConversion { buffer_args.push_back( InsertAlloc(op->getLoc(), result.value(), &rewriter)); } else { - SmallVector results_shape; auto shape_type_op = dyn_cast(op); if (!shape_type_op) return failure(); - if (failed( - shape_type_op.reifyReturnTypeShapes(rewriter, results_shape))) - return failure(); + + SmallVector results_shape; + auto status = + shape_type_op.reifyReturnTypeShapes(rewriter, results_shape); + if (failed(status)) return failure(); buffer_args.push_back(InsertDynamicAllocAndDealloc( op->getLoc(), result.value(), results_shape.front(), &rewriter)); } } rewriter.create>(op->getLoc(), llvm::None, buffer_args, op->getAttrs()); - rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); + rewriter.replaceOp( + op, llvm::makeArrayRef(buffer_args).drop_front(operands.size())); return success(); } };