[MLIR] Convert FuncOp signature with unranked types in HLO->LHLO conversion.
PiperOrigin-RevId: 320146856
This commit is contained in:
parent
e1651b6090
commit
8692fde3f9
|
@ -391,16 +391,15 @@ struct HloLegalizeToLhlo
|
|||
target.addIllegalDialect<mhlo::XlaHloDialect>();
|
||||
|
||||
BufferAssignmentTypeConverter converter;
|
||||
auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
|
||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||
auto inputs = op.getType().getInputs();
|
||||
return llvm::all_of(inputs,
|
||||
[](Type input) { return input.isa<MemRefType>(); }) &&
|
||||
return llvm::all_of(inputs, isMemRefType) &&
|
||||
converter.isLegal(&op.getBody());
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
|
||||
return std::all_of(returnOp.operand_type_begin(),
|
||||
returnOp.operand_type_end(),
|
||||
[](Type type) { return type.isa<MemRefType>(); });
|
||||
returnOp.operand_type_end(), isMemRefType);
|
||||
});
|
||||
|
||||
auto module = getOperation();
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement %s -o - | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @func_op_unranked_arg_result
|
||||
func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
return %arg0 : tensor<*xf32>
|
||||
}
|
||||
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32>
|
||||
// CHECK-NEXT: return [[ARG]] : memref<*xf32>
|
Loading…
Reference in New Issue