PR #43069: [hlo] Unbreak hlo-legalize-to-lhlo test

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/43069

The legalization of mlhlo.ReturnOp to lhlo.TerminatorOp by using BufferAssignmentReturnOpConverter fails since the Memref typed results (or the Memref typed operands of Return operation) are set to stay as results after legalization but lhlo.TerminatorOp doesn't accept any operands. Therefore, BufferAssignmentReturnOpConverter must be replaced with a manual conversion that removes all operands of mlhlo.ReturnOp and inserts copy operations in their places.
Copybara import of the project:

--
8be0435b0147263c3872bedec58fd215f784b450 by Ehsan Toosi <ehsan.nadjaran_toosi@dfki.de>:

[hlo] Unbreak hlo-legalize-to-lhlo test

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/43069 from dfki-ehna:fix_hlo_legalize_to_lhlo_test 8be0435b0147263c3872bedec58fd215f784b450
PiperOrigin-RevId: 330907602
This commit is contained in:
Ehsan Toosi 2020-09-10 04:07:57 -07:00 committed by TensorFlow MLIR Team
parent f46ba09653
commit d599485e06
2 changed files with 37 additions and 8 deletions

View File

@ -287,11 +287,36 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
} }
}; };
// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. This functionality // Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator.
// is provided by mlir buffer assignment, so use the pattern from there. struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
using HloToLhloReturnOpConverter = public:
BufferAssignmentReturnOpConverter<mhlo::ReturnOp, lmhlo::TerminatorOp, using BaseOpConversion<mhlo::ReturnOp>::BaseOpConversion;
lmhlo::CopyOp>;
LogicalResult matchAndRewrite(
mhlo::ReturnOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
auto& entry_block = op.getParentRegion()->front();
auto num_arguments = entry_block.getNumArguments();
if (operands.size() > num_arguments) {
return op.emitError(
"The number of operands that need Copy operations is more "
"than the number of target function arguments.");
}
// The index of the first output block argument.
auto dest_arg_idx = num_arguments - operands.size();
// Create a lmhlo.copy for each operand of mhlo.return.
for (Value operand : operands) {
rewriter.create<lmhlo::CopyOp>(loc, operand,
entry_block.getArgument(dest_arg_idx));
++dest_arg_idx;
}
rewriter.replaceOpWithNewOp<lmhlo::TerminatorOp>(op);
return success();
}
};
class HloToLhloTensorLoadOpConverter class HloToLhloTensorLoadOpConverter
: public BaseOpConversion<mlir::TensorLoadOp> { : public BaseOpConversion<mlir::TensorLoadOp> {
@ -429,6 +454,13 @@ struct HloLegalizeToLhlo
isMemRefType); isMemRefType);
}); });
auto kind = results_escape_function
? BufferAssignmentTypeConverter::KeepAsFunctionResult
: BufferAssignmentTypeConverter::AppendToArgumentsList;
converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
kind);
converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
populateHLOToLHLOConversionPattern(&context, &converter, &patterns); populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
populateWithBufferAssignmentOpConversionPatterns< populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter, mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter,

View File

@ -1,8 +1,5 @@
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s // RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s // RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s
// TODO(herhut): unbreak the test after upstream API changes.
// XFAIL: *
// BOTH-LABEL: func @attrs // BOTH-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {