diff --git a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 83f60c5..8751ab4 100644 --- a/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -391,16 +391,15 @@ struct HloLegalizeToLhlo target.addIllegalDialect(); BufferAssignmentTypeConverter converter; + auto isMemRefType = [](Type type) { return type.isa(); }; target.addDynamicallyLegalOp([&](FuncOp op) { auto inputs = op.getType().getInputs(); - return llvm::all_of(inputs, - [](Type input) { return input.isa(); }) && + return llvm::all_of(inputs, isMemRefType) && converter.isLegal(&op.getBody()); }); target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { return std::all_of(returnOp.operand_type_begin(), - returnOp.operand_type_end(), - [](Type type) { return type.isa(); }); + returnOp.operand_type_end(), isMemRefType); }); auto module = getOperation(); diff --git a/tests/hlo-legalize-to-lhlo-unranked.mlir b/tests/hlo-legalize-to-lhlo-unranked.mlir new file mode 100644 index 0000000..063716a --- /dev/null +++ b/tests/hlo-legalize-to-lhlo-unranked.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement %s -o - | FileCheck %s + +// CHECK-LABEL: func @func_op_unranked_arg_result +func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> { + return %arg0 : tensor<*xf32> +} +// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32> +// CHECK-NEXT: return [[ARG]] : memref<*xf32>