diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 8425f0f..d14072a 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -42,9 +42,6 @@ namespace { template using BaseOpConversion = BufferAssignmentOpConversionPattern; -using StdReturnOpConverter = - detail::BufferAssignmentReturnOpConverter; Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value shape_operand, @@ -272,27 +269,21 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { // Copy over the operations inside the region. rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end()); - // Create new block arguments with correct type. + // Convert the region signature to memref and add extra result. auto& entry_block = new_op.body().front(); - int original_arg_count = entry_block.getNumArguments(); - for (int i = 0; i < original_arg_count; ++i) { - auto old_arg = entry_block.getArgument(i); - auto old_type = old_arg.getType().cast(); + TypeConverter::SignatureConversion sig_conversion( + entry_block.getNumArguments() + 1); + for (auto arg : entry_block.getArguments()) { + auto old_type = arg.getType().cast(); auto new_type = MemRefType::get(old_type.getShape(), old_type.getElementType()); - auto new_arg = entry_block.addArgument(new_type); - rewriter.replaceUsesOfBlockArgument(old_arg, new_arg); + sig_conversion.addInputs(arg.getArgNumber(), new_type); } - // Add an argument for the result. - entry_block.addArgument( - entry_block.getArgument(original_arg_count).getType()); - // Remove the old arguments. - for (int i = original_arg_count - 1; i >= 0; --i) { - entry_block.eraseArgument(i); - } - // Insert terminator at the end. - rewriter.setInsertionPointToEnd(&entry_block); - rewriter.create(loc); + auto return_op = cast(entry_block.getTerminator()); + auto result_type = return_op.results().front().getType().cast(); + sig_conversion.addInputs({MemRefType::get(result_type.getShape(), + result_type.getElementType())}); + rewriter.applySignatureConversion(&new_op.body(), sig_conversion); rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); @@ -300,6 +291,12 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { } }; +// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. This functionality +// is provided by mlir buffer assignment, so use the pattern from there. +// TODO(DFKI): Move this out of detail. +using HloToLhloReturnOpConverter = detail::BufferAssignmentReturnOpConverter< + mhlo::ReturnOp, lmhlo::TerminatorOp, lmhlo::CopyOp, false>; + class HloToLhloTensorLoadOpConverter : public BaseOpConversion { public: @@ -312,7 +309,6 @@ class HloToLhloTensorLoadOpConverter } }; -// TODO(b/137624192): Rewrite into a copy and elide copy if possible. class HloToLhloTensorStoreOpConverter : public BaseOpConversion { public: @@ -506,6 +502,7 @@ void populateHLOToLHLOConversionPattern( HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloReduceOpConverter, + HloToLhloReturnOpConverter, HloToLhloTensorLoadOpConverter, HloToLhloTensorStoreOpConverter >(context, bufferAssignment, converter); diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index aa5d800..789ab82 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -487,3 +487,26 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor } : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> return %out : tensor<3x5x5x4xf32> } + +// ----- + +// BOTH-LABEL: func @reduce +func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor) -> tensor<1xf32> { + // BOTH: %[[OUT:.*]] = alloc() : memref<1xf32> + // BOTH: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( { + // BOTH: ^bb0(%[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref, + // BOTH-SAME: %[[ARG3:.*]]: memref): + // BOTH: %[[TMP:.*]] = alloc() : memref + // BOTH: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]]) + // BOTH: "lmhlo.copy"(%[[TMP]], %[[ARG3]]) + // BOTH: "lmhlo.terminator"() : () -> () + // BOTH: }) {dimensions = dense<1> : tensor<1xi64>} + // BOTH-SAME: : (memref<1x8xf32>, memref, memref<1xf32>) -> () + %0 = "mhlo.reduce"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): // no predecessors + %1 = mhlo.add %arg2, %arg3 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<1> : tensor<1xi64>} + : (tensor<1x8xf32>, tensor) -> tensor<1xf32> + return %0 : tensor<1xf32> +} diff --git a/tests/lhlo-copy-removal.mlir b/tests/lhlo-copy-removal.mlir index 6d7992c..3271595 100644 --- a/tests/lhlo-copy-removal.mlir +++ b/tests/lhlo-copy-removal.mlir @@ -91,3 +91,25 @@ func @must_be_removed_second(%arg0: memref<2x2xf32>, dealloc %0 : memref<2x2xf32> "lmhlo.terminator"() : () -> () } + +// ----- + +// CHECK-LABEL: func @reduce +func @reduce(%arg0: memref<1x8xf32>, %arg1: memref, %arg2: memref<1xf32>) { + %0 = alloc() : memref<1xf32> + "lmhlo.reduce"(%arg0, %arg1, %0) ( { + // CHECK: ^bb0(%[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, + // CHECK-SAME: %[[ARG2:.*]]: memref) + ^bb0(%arg3: memref, %arg4: memref, %arg5: memref): + %1 = alloc() : memref + // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) + "lmhlo.add"(%arg3, %arg4, %1) + : (memref, memref, memref) -> () + // CHECK-NOT; lmhlo.copy + "lmhlo.copy"(%1, %arg5) : (memref, memref) -> () + "lmhlo.terminator"() : () -> () + }) {dimensions = dense<1> : tensor<1xi64>} + : (memref<1x8xf32>, memref, memref<1xf32>) -> () + "lmhlo.copy"(%0, %arg2) : (memref<1xf32>, memref<1xf32>) -> () + return +}