PR #43069: [hlo] Unbreak hlo-legalize-to-lhlo test
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/43069 The legalization of mlhlo.ReturnOp to lhlo.TerminatorOp by using BufferAssignmentReturnOpConverter fails since the Memref typed results (or the Memref typed operands of Return operation) are set to stay as results after legalization but lhlo.TerminatorOp doesn't accept any operands. Therefore, BufferAssignmentReturnOpConverter must be replaced with a manual conversion that removes all operands of mlhlo.ReturnOp and inserts copy operations in their places. Copybara import of the project: -- 8be0435b0147263c3872bedec58fd215f784b450 by Ehsan Toosi <ehsan.nadjaran_toosi@dfki.de>: [hlo] Unbreak hlo-legalize-to-lhlo test COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/43069 from dfki-ehna:fix_hlo_legalize_to_lhlo_test 8be0435b0147263c3872bedec58fd215f784b450 PiperOrigin-RevId: 330907602
This commit is contained in:
parent
f46ba09653
commit
d599485e06
|
@ -287,11 +287,36 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
|
|||
}
|
||||
};
|
||||
|
||||
// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. This functionality
|
||||
// is provided by mlir buffer assignment, so use the pattern from there.
|
||||
using HloToLhloReturnOpConverter =
|
||||
BufferAssignmentReturnOpConverter<mhlo::ReturnOp, lmhlo::TerminatorOp,
|
||||
lmhlo::CopyOp>;
|
||||
// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator.
|
||||
struct HloToLhloReturnOpConverter : public BaseOpConversion<mhlo::ReturnOp> {
|
||||
public:
|
||||
using BaseOpConversion<mhlo::ReturnOp>::BaseOpConversion;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::ReturnOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto loc = op.getLoc();
|
||||
auto& entry_block = op.getParentRegion()->front();
|
||||
auto num_arguments = entry_block.getNumArguments();
|
||||
if (operands.size() > num_arguments) {
|
||||
return op.emitError(
|
||||
"The number of operands that need Copy operations is more "
|
||||
"than the number of target function arguments.");
|
||||
}
|
||||
|
||||
// The index of the first output block argument.
|
||||
auto dest_arg_idx = num_arguments - operands.size();
|
||||
|
||||
// Create a lmhlo.copy for each operand of mhlo.return.
|
||||
for (Value operand : operands) {
|
||||
rewriter.create<lmhlo::CopyOp>(loc, operand,
|
||||
entry_block.getArgument(dest_arg_idx));
|
||||
++dest_arg_idx;
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<lmhlo::TerminatorOp>(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class HloToLhloTensorLoadOpConverter
|
||||
: public BaseOpConversion<mlir::TensorLoadOp> {
|
||||
|
@ -429,6 +454,13 @@ struct HloLegalizeToLhlo
|
|||
isMemRefType);
|
||||
});
|
||||
|
||||
auto kind = results_escape_function
|
||||
? BufferAssignmentTypeConverter::KeepAsFunctionResult
|
||||
: BufferAssignmentTypeConverter::AppendToArgumentsList;
|
||||
converter.setResultConversionKind<UnrankedTensorType, UnrankedMemRefType>(
|
||||
kind);
|
||||
converter.setResultConversionKind<RankedTensorType, MemRefType>(kind);
|
||||
|
||||
populateHLOToLHLOConversionPattern(&context, &converter, &patterns);
|
||||
populateWithBufferAssignmentOpConversionPatterns<
|
||||
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp>(&context, &converter,
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=PRE,BOTH %s
|
||||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck --check-prefixes=ESC,BOTH %s
|
||||
// TODO(herhut): unbreak the test after upstream API changes.
|
||||
// XFAIL: *
|
||||
|
||||
|
||||
// BOTH-LABEL: func @attrs
|
||||
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
|
|
Loading…
Reference in New Issue