diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 7b47e60..3e4eb4a 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -40,4 +40,14 @@ func @test_matmul_3(%arg0 : tensor<64x42xf32>, %arg1 : tensor<16x?x42x32xf32>) - // CHECK-LABEL: test_matmul_3 // CHECK: [[RES3:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<16x?x42x32xf32>) -> tensor<16x?x64x32xf32> // CHECK: return [[RES3]] : tensor<16x?x64x32xf32> +} + +/// MatMul: 2-D x K-D (K > 2) +func @test_matmul_4(%arg0 : tensor<64x42xf32>, %arg1 : tensor) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor) -> tensor<*xf32> + "std.return"(%0) : (tensor<*xf32>) -> () + + // CHECK-LABEL: test_matmul_4 + // CHECK: [[RES4:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor) -> tensor + // CHECK: return [[RES4]] : tensor } \ No newline at end of file