[MLIR] Convert FuncOp signature with unranked types in HLO->LHLO conversion.

PiperOrigin-RevId: 320146856
This commit is contained in:
Alexander Belyaev 2020-07-08 08:43:30 +00:00 committed by Mehdi Amini
parent e1651b6090
commit 8692fde3f9
2 changed files with 11 additions and 4 deletions

View File

@ -391,16 +391,15 @@ struct HloLegalizeToLhlo
target.addIllegalDialect<mhlo::XlaHloDialect>();
BufferAssignmentTypeConverter converter;
auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
auto inputs = op.getType().getInputs();
return llvm::all_of(inputs,
[](Type input) { return input.isa<MemRefType>(); }) &&
return llvm::all_of(inputs, isMemRefType) &&
converter.isLegal(&op.getBody());
});
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
return std::all_of(returnOp.operand_type_begin(),
returnOp.operand_type_end(),
[](Type type) { return type.isa<MemRefType>(); });
returnOp.operand_type_end(), isMemRefType);
});
auto module = getOperation();

View File

@ -0,0 +1,8 @@
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement %s -o - | FileCheck %s
// CHECK-LABEL: func @func_op_unranked_arg_result
func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
return %arg0 : tensor<*xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32>
// CHECK-NEXT: return [[ARG]] : memref<*xf32>