Add test for multypling stacks of matrices.
This commit is contained in:
		
							parent
							
								
									ae966cdee9
								
							
						
					
					
						commit
						95ebf3e23a
					
				| 
						 | 
					@ -50,4 +50,14 @@ func @test_matmul_4(%arg0 : tensor<64x42xf32>, %arg1 : tensor<?x?x?x?xf32>) -> t
 | 
				
			||||||
  // CHECK-LABEL: test_matmul_4
 | 
					  // CHECK-LABEL: test_matmul_4
 | 
				
			||||||
  // CHECK: [[RES4:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x64x?xf32>
 | 
					  // CHECK: [[RES4:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<64x42xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x64x?xf32>
 | 
				
			||||||
  // CHECK: return [[RES4]] : tensor<?x?x64x?xf32>
 | 
					  // CHECK: return [[RES4]] : tensor<?x?x64x?xf32>
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// MatMul: K1-D x K2-D (K1 > 2, K2 > 2)
 | 
				
			||||||
 | 
					func @test_matmul_5(%arg0 : tensor<16x?x64x42xf32>, %arg1 : tensor<32x?x64x42x32xf32>) -> tensor<*xf32> {
 | 
				
			||||||
 | 
					  %0 = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<*xf32>
 | 
				
			||||||
 | 
					  "std.return"(%0) : (tensor<*xf32>) -> ()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // CHECK-LABEL: test_matmul_5
 | 
				
			||||||
 | 
					  // CHECK: [[RES5:%.+]] = "onnx.MatMul"(%arg0, %arg1) : (tensor<16x?x64x42xf32>, tensor<32x?x64x42x32xf32>) -> tensor<32x16x64x64x32xf32>
 | 
				
			||||||
 | 
					  // CHECK: return [[RES5]] : tensor<32x16x64x64x32xf32>
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
		Loading…
	
		Reference in New Issue