diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp index 044755c..bb4c618 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp @@ -181,7 +181,8 @@ struct TensorTypeConverter : public TypeConverter { /// override needs to be removed in favour of the original MLIR one.] bool isSignatureLegal(FunctionType funcType) { return llvm::all_of( - funcType.getInputs(), [this](Type type) { return isLegal(type); }); + llvm::concat(funcType.getInputs(), funcType.getResults()), + [this](Type type) { return isLegal(type); }); } }; diff --git a/test/mlir/onnx/onnx_lowering.mlir b/test/mlir/onnx/onnx_lowering.mlir index 2757cf4..a9eff20 100644 --- a/test/mlir/onnx/onnx_lowering.mlir +++ b/test/mlir/onnx/onnx_lowering.mlir @@ -1,5 +1,21 @@ // RUN: onnx-mlir-opt --shape-inference --lower-frontend %s -split-input-file | FileCheck %s +// ---- + +func @test_no_argument_1() -> () { +} + +func @test_no_argument_2() -> tensor<*xf32> { + %0 = "onnx.Constant"() {value = dense<[[1.000000e+0, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf32>} : () -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + +} + +// CHECK: test_no_argument_1 +// CHECK-NEXT: test_no_argument_2 +// CHECK: [[RES:%.+]] = "{{.*}}"({{.*}}) {{.*}} : ({{.*}}) -> memref<2x2xf32> +// CHECK: return [[RES]] : memref<2x2xf32> + // ----- func @test_add(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> {