diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 7c44535..64f1533 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -53,11 +53,11 @@ func @test_matmul_4(%arg0 : tensor<64x42xf32>, %arg1 : tensor) -> t } /// MatMul: K1-D x K2-D (K1 > 2, K2 > 2) -func @test_matmul_5(%arg0 : tensor<16x?x64x42xf32>, %arg1 : tensor<32x?x64x42x32xf32>) -> tensor<*xf32> { - %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<*xf32> +func @test_matmul_5(%arg0 : tensor<16x?x?x42xf32>, %arg1 : tensor<32x?x64x42x32xf32>) -> tensor<*xf32> { + %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<*xf32> "std.return"(%0) : (tensor<*xf32>) -> () // CHECK-LABEL: test_matmul_5 - // CHECK: [[RES5:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<32x16x64x64x32xf32> - // CHECK: return [[RES5]] : tensor<32x16x64x64x32xf32> + // CHECK: [[RES5:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x?x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<32x16x64x?x32xf32> + // CHECK: return [[RES5]] : tensor<32x16x64x?x32xf32> } \ No newline at end of file