Fix mhlo to lmhlo conversion for ReduceOp.

The existing conversion no longer worked and was not safe to undo. Furthermore, the pattern for mhlo.return had been removed.

Also adds some tests to ensure this does not degrade again.

PiperOrigin-RevId: 321542071
This commit is contained in:
Stephan Herhut 2020-07-16 11:40:32 +00:00 committed by Mehdi Amini
parent c8bb0ff54d
commit c44e08351d
3 changed files with 63 additions and 21 deletions

View File

@ -42,9 +42,6 @@ namespace {
template <typename T> template <typename T>
using BaseOpConversion = BufferAssignmentOpConversionPattern<T>; using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
using StdReturnOpConverter =
detail::BufferAssignmentReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
lmhlo::CopyOp, true>;
Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value InsertDynamicAllocAndDealloc(Location loc, Value result,
Value shape_operand, Value shape_operand,
@ -272,27 +269,21 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
// Copy over the operations inside the region. // Copy over the operations inside the region.
rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end()); rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
// Create new block arguments with correct type. // Convert the region signature to memref and add extra result.
auto& entry_block = new_op.body().front(); auto& entry_block = new_op.body().front();
int original_arg_count = entry_block.getNumArguments(); TypeConverter::SignatureConversion sig_conversion(
for (int i = 0; i < original_arg_count; ++i) { entry_block.getNumArguments() + 1);
auto old_arg = entry_block.getArgument(i); for (auto arg : entry_block.getArguments()) {
auto old_type = old_arg.getType().cast<TensorType>(); auto old_type = arg.getType().cast<TensorType>();
auto new_type = auto new_type =
MemRefType::get(old_type.getShape(), old_type.getElementType()); MemRefType::get(old_type.getShape(), old_type.getElementType());
auto new_arg = entry_block.addArgument(new_type); sig_conversion.addInputs(arg.getArgNumber(), new_type);
rewriter.replaceUsesOfBlockArgument(old_arg, new_arg);
} }
// Add an argument for the result. auto return_op = cast<mhlo::ReturnOp>(entry_block.getTerminator());
entry_block.addArgument( auto result_type = return_op.results().front().getType().cast<TensorType>();
entry_block.getArgument(original_arg_count).getType()); sig_conversion.addInputs({MemRefType::get(result_type.getShape(),
// Remove the old arguments. result_type.getElementType())});
for (int i = original_arg_count - 1; i >= 0; --i) { rewriter.applySignatureConversion(&new_op.body(), sig_conversion);
entry_block.eraseArgument(i);
}
// Insert terminator at the end.
rewriter.setInsertionPointToEnd(&entry_block);
rewriter.create<lmhlo::TerminatorOp>(loc);
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size())); rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
@ -300,6 +291,12 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
} }
}; };
// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. This functionality
// is provided by mlir buffer assignment, so use the pattern from there.
// TODO(DFKI): Move this out of detail.
using HloToLhloReturnOpConverter = detail::BufferAssignmentReturnOpConverter<
mhlo::ReturnOp, lmhlo::TerminatorOp, lmhlo::CopyOp, false>;
class HloToLhloTensorLoadOpConverter class HloToLhloTensorLoadOpConverter
: public BaseOpConversion<mlir::TensorLoadOp> { : public BaseOpConversion<mlir::TensorLoadOp> {
public: public:
@ -312,7 +309,6 @@ class HloToLhloTensorLoadOpConverter
} }
}; };
// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
class HloToLhloTensorStoreOpConverter class HloToLhloTensorStoreOpConverter
: public BaseOpConversion<mlir::TensorStoreOp> { : public BaseOpConversion<mlir::TensorStoreOp> {
public: public:
@ -506,6 +502,7 @@ void populateHLOToLHLOConversionPattern(
HloToLhloOpConverter<mhlo::SubOp>, HloToLhloOpConverter<mhlo::SubOp>,
HloToLhloOpConverter<mhlo::TanhOp>, HloToLhloOpConverter<mhlo::TanhOp>,
HloToLhloReduceOpConverter, HloToLhloReduceOpConverter,
HloToLhloReturnOpConverter,
HloToLhloTensorLoadOpConverter, HloToLhloTensorLoadOpConverter,
HloToLhloTensorStoreOpConverter HloToLhloTensorStoreOpConverter
>(context, bufferAssignment, converter); >(context, bufferAssignment, converter);

View File

@ -487,3 +487,26 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor
} : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32> } : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
return %out : tensor<3x5x5x4xf32> return %out : tensor<3x5x5x4xf32>
} }
// -----
// BOTH-LABEL: func @reduce
func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
// Expected lowering: the result tensor becomes an allocated output buffer,
// the reduction region gets memref block arguments plus one extra argument
// for the result, and mhlo.return turns into a copy into that extra
// argument followed by an lmhlo.terminator.
// BOTH: %[[OUT:.*]] = alloc() : memref<1xf32>
// BOTH: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
// BOTH: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
// BOTH-SAME: %[[ARG3:.*]]: memref<f32>):
// BOTH: %[[TMP:.*]] = alloc() : memref<f32>
// BOTH: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
// BOTH: "lmhlo.copy"(%[[TMP]], %[[ARG3]])
// BOTH: "lmhlo.terminator"() : () -> ()
// BOTH: }) {dimensions = dense<1> : tensor<1xi64>}
// BOTH-SAME: : (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
%0 = "mhlo.reduce"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>}
: (tensor<1x8xf32>, tensor<f32>) -> tensor<1xf32>
return %0 : tensor<1xf32>
}

View File

@ -91,3 +91,25 @@ func @must_be_removed_second(%arg0: memref<2x2xf32>,
dealloc %0 : memref<2x2xf32> dealloc %0 : memref<2x2xf32>
"lmhlo.terminator"() : () -> () "lmhlo.terminator"() : () -> ()
} }
// -----
// CHECK-LABEL: func @reduce
func @reduce(%arg0: memref<1x8xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) {
// Temporary output buffer; per the CHECK lines below the pass is expected to
// forward values so the copies inside the region disappear.
%0 = alloc() : memref<1xf32>
"lmhlo.reduce"(%arg0, %arg1, %0) ( {
// CHECK: ^bb0(%[[ARG0:.*]]: memref<f32>, %[[ARG1:.*]]: memref<f32>,
// CHECK-SAME: %[[ARG2:.*]]: memref<f32>)
^bb0(%arg3: memref<f32>, %arg4: memref<f32>, %arg5: memref<f32>):
%1 = alloc() : memref<f32>
// CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
"lmhlo.add"(%arg3, %arg4, %1)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
// CHECK-NOT: lmhlo.copy
"lmhlo.copy"(%1, %arg5) : (memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
}) {dimensions = dense<1> : tensor<1xi64>}
: (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<1xf32>, memref<1xf32>) -> ()
return
}