From c44e08351d8c9e935c252b6cb8ef102dd25036bd Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Thu, 16 Jul 2020 11:40:32 +0000 Subject: [PATCH] Fix mhlo to lmhlo conversion for ReduceOp. The existing conversion no longer worked and was not safe to undo. Furthermore, the pattern for mhlo.return had been removed. Also adds some tests to ensure this does not degrade again. PiperOrigin-RevId: 321542071 --- .../mhlo/transforms/hlo_legalize_to_lhlo.cc | 39 +++++++++---------- tests/hlo-legalize-to-lhlo.mlir | 23 +++++++++++ tests/lhlo-copy-removal.mlir | 22 +++++++++++ 3 files changed, 63 insertions(+), 21 deletions(-) diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 8425f0f..d14072a 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -42,9 +42,6 @@ namespace { template using BaseOpConversion = BufferAssignmentOpConversionPattern; -using StdReturnOpConverter = - detail::BufferAssignmentReturnOpConverter; Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value shape_operand, @@ -272,27 +269,21 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { // Copy over the operations inside the region. rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end()); - // Create new block arguments with correct type. + // Convert the region signature to memref and add extra result. 
auto& entry_block = new_op.body().front(); - int original_arg_count = entry_block.getNumArguments(); - for (int i = 0; i < original_arg_count; ++i) { - auto old_arg = entry_block.getArgument(i); - auto old_type = old_arg.getType().cast(); + TypeConverter::SignatureConversion sig_conversion( + entry_block.getNumArguments() + 1); + for (auto arg : entry_block.getArguments()) { + auto old_type = arg.getType().cast(); auto new_type = MemRefType::get(old_type.getShape(), old_type.getElementType()); - auto new_arg = entry_block.addArgument(new_type); - rewriter.replaceUsesOfBlockArgument(old_arg, new_arg); + sig_conversion.addInputs(arg.getArgNumber(), new_type); } - // Add an argument for the result. - entry_block.addArgument( - entry_block.getArgument(original_arg_count).getType()); - // Remove the old arguments. - for (int i = original_arg_count - 1; i >= 0; --i) { - entry_block.eraseArgument(i); - } - // Insert terminator at the end. - rewriter.setInsertionPointToEnd(&entry_block); - rewriter.create(loc); + auto return_op = cast(entry_block.getTerminator()); + auto result_type = return_op.results().front().getType().cast(); + sig_conversion.addInputs({MemRefType::get(result_type.getShape(), + result_type.getElementType())}); + rewriter.applySignatureConversion(&new_op.body(), sig_conversion); rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); @@ -300,6 +291,12 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { } }; +// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. This functionality +// is provided by mlir buffer assignment, so use the pattern from there. +// TODO(DFKI): Move this out of detail. 
+using HloToLhloReturnOpConverter = detail::BufferAssignmentReturnOpConverter< + mhlo::ReturnOp, lmhlo::TerminatorOp, lmhlo::CopyOp, false>; + class HloToLhloTensorLoadOpConverter : public BaseOpConversion { public: @@ -312,7 +309,6 @@ class HloToLhloTensorLoadOpConverter } }; -// TODO(b/137624192): Rewrite into a copy and elide copy if possible. class HloToLhloTensorStoreOpConverter : public BaseOpConversion { public: @@ -506,6 +502,7 @@ void populateHLOToLHLOConversionPattern( HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloReduceOpConverter, + HloToLhloReturnOpConverter, HloToLhloTensorLoadOpConverter, HloToLhloTensorStoreOpConverter >(context, bufferAssignment, converter); diff --git a/tests/hlo-legalize-to-lhlo.mlir b/tests/hlo-legalize-to-lhlo.mlir index aa5d800..789ab82 100644 --- a/tests/hlo-legalize-to-lhlo.mlir +++ b/tests/hlo-legalize-to-lhlo.mlir @@ -487,3 +487,26 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor } : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> return %out : tensor<3x5x5x4xf32> } + +// ----- + +// BOTH-LABEL: func @reduce +func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor) -> tensor<1xf32> { + // BOTH: %[[OUT:.*]] = alloc() : memref<1xf32> + // BOTH: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( { + // BOTH: ^bb0(%[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref, + // BOTH-SAME: %[[ARG3:.*]]: memref): + // BOTH: %[[TMP:.*]] = alloc() : memref + // BOTH: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]]) + // BOTH: "lmhlo.copy"(%[[TMP]], %[[ARG3]]) + // BOTH: "lmhlo.terminator"() : () -> () + // BOTH: }) {dimensions = dense<1> : tensor<1xi64>} + // BOTH-SAME: : (memref<1x8xf32>, memref, memref<1xf32>) -> () + %0 = "mhlo.reduce"(%arg0, %arg1) ( { + ^bb0(%arg2: tensor, %arg3: tensor): // no predecessors + %1 = mhlo.add %arg2, %arg3 : tensor + "mhlo.return"(%1) : (tensor) -> () + }) {dimensions = dense<1> : tensor<1xi64>} + : (tensor<1x8xf32>, tensor) -> tensor<1xf32> + return %0 : tensor<1xf32> 
+} diff --git a/tests/lhlo-copy-removal.mlir b/tests/lhlo-copy-removal.mlir index 6d7992c..3271595 100644 --- a/tests/lhlo-copy-removal.mlir +++ b/tests/lhlo-copy-removal.mlir @@ -91,3 +91,25 @@ func @must_be_removed_second(%arg0: memref<2x2xf32>, dealloc %0 : memref<2x2xf32> "lmhlo.terminator"() : () -> () } + +// ----- + +// CHECK-LABEL: func @reduce +func @reduce(%arg0: memref<1x8xf32>, %arg1: memref, %arg2: memref<1xf32>) { + %0 = alloc() : memref<1xf32> + "lmhlo.reduce"(%arg0, %arg1, %0) ( { + // CHECK: ^bb0(%[[ARG0:.*]]: memref, %[[ARG1:.*]]: memref, + // CHECK-SAME: %[[ARG2:.*]]: memref) + ^bb0(%arg3: memref, %arg4: memref, %arg5: memref): + %1 = alloc() : memref + // CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) + "lmhlo.add"(%arg3, %arg4, %1) + : (memref, memref, memref) -> () + // CHECK-NOT: lmhlo.copy + "lmhlo.copy"(%1, %arg5) : (memref, memref) -> () + "lmhlo.terminator"() : () -> () + }) {dimensions = dense<1> : tensor<1xi64>} + : (memref<1x8xf32>, memref, memref<1xf32>) -> () + "lmhlo.copy"(%0, %arg2) : (memref<1xf32>, memref<1xf32>) -> () + return +}