diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 6710d37..24f12d4 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -87,6 +87,32 @@ Value InsertAlloc(Location loc, OpResult result, return alloc; } +/// Converts the results of the operation `op` to memref types and append them +/// to the `results` vector. +LogicalResult ConvertResults(Operation* op, SmallVectorImpl& results, + ConversionPatternRewriter& rewriter) { + for (auto result : llvm::enumerate(op->getResults())) { + RankedTensorType resultType = + result.value().getType().dyn_cast(); + if (!resultType) return failure(); + + if (resultType.hasStaticShape()) { + results.push_back(InsertAlloc(op->getLoc(), result.value(), &rewriter)); + continue; + } + auto shape_type_op = dyn_cast(op); + if (!shape_type_op) return failure(); + + SmallVector results_shape; + auto status = shape_type_op.reifyReturnTypeShapes(rewriter, results_shape); + if (failed(status)) return failure(); + results.push_back( + InsertDynamicAllocAndDealloc(op->getLoc(), result.value(), + results_shape[result.index()], &rewriter)); + } + return success(); +} + template class HloToLhloOpConverter : public BaseOpConversion { public: @@ -95,29 +121,8 @@ class HloToLhloOpConverter : public BaseOpConversion { HloOpTy hloOp, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { Operation* op = hloOp.getOperation(); - const auto& original_results = op->getResults(); SmallVector buffer_args(operands.begin(), operands.end()); - for (auto result : llvm::enumerate(original_results)) { - RankedTensorType resultType = - result.value().getType().dyn_cast(); - if (!resultType) { - return failure(); - } - if (resultType.hasStaticShape()) { - buffer_args.push_back( - InsertAlloc(op->getLoc(), result.value(), &rewriter)); - } else { - auto shape_type_op = dyn_cast(op); - if (!shape_type_op) return failure(); - - SmallVector results_shape; - auto status = - shape_type_op.reifyReturnTypeShapes(rewriter, results_shape); - if (failed(status)) return failure(); - buffer_args.push_back(InsertDynamicAllocAndDealloc( - op->getLoc(), result.value(), results_shape.front(), &rewriter)); - } - } + if (failed(ConvertResults(op, buffer_args, rewriter))) return failure(); rewriter.create>(op->getLoc(), llvm::None, buffer_args, op->getAttrs()); rewriter.replaceOp( @@ -139,28 +144,8 @@ class HloToLhloOpConverter : public BaseOpConversion { mhlo::DotOp hloOp, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { Operation* op = hloOp.getOperation(); - const auto& original_results = op->getResults(); SmallVector buffer_args(operands.begin(), operands.end()); - for (auto result : llvm::enumerate(original_results)) { - RankedTensorType resultType = - result.value().getType().dyn_cast(); - if (!resultType) { - return failure(); - } - if (resultType.hasStaticShape()) { - buffer_args.push_back( - InsertAlloc(op->getLoc(), result.value(), &rewriter)); - } else { - SmallVector results_shape; - auto shape_type_op = dyn_cast(op); - if (!shape_type_op) return failure(); - if (failed( - shape_type_op.reifyReturnTypeShapes(rewriter, results_shape))) - return failure(); - buffer_args.push_back(InsertDynamicAllocAndDealloc( - op->getLoc(), result.value(), results_shape.front(), &rewriter)); - } - } + if (failed(ConvertResults(op, buffer_args, rewriter))) return failure(); // TODO(silvasean): Move this helper to MLIR core. auto make_elements_attr = [&rewriter](ArrayRef integers) {