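# Einsum examples (PyTorch): inner/outer products, matrix multiplication,
# tensor contraction, and a bilinear form cross-checked three ways.
# (Original listing header: 68 lines, 1.7 KiB, Python.)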
import torch
import numpy as np

# Example 5: vector inner (dot) product.
# "i,i->" multiplies matching entries and sums over the repeated index i,
# leaving a 0-dim tensor -- the same result as torch.dot(A, B).
A, B = torch.randn(10), torch.randn(10)
C = torch.einsum("i,i->", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)
# Example 6: vector outer product.
# "i,j->ij" has no repeated index, so nothing is summed: every A[i] is
# multiplied by every B[j], giving a (10, 5) matrix like torch.outer(A, B).
A, B = torch.randn(10), torch.randn(5)
C = torch.einsum("i,j->ij", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)
# Example 7: matrix multiplication.
# The shared index k is contracted: C[i, j] = sum_k A[i, k] * B[k, j],
# i.e. torch.matmul(A, B) for a (5, 4) @ (4, 6) product.
A, B = torch.randn(5, 4), torch.randn(4, 6)
C = torch.einsum("ik,kj->ij", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)
# Example 8: tensor contraction over two axes.
# A is indexed (i=3, j=4, k=5) and B is indexed (j=4, i=3, h=6); summing over
# both i and j leaves a (5, 6) result, matching
# torch.tensordot(A, B, dims=[(0, 1), (1, 0)]).
A, B = torch.randn(3, 4, 5), torch.randn(4, 3, 6)
C = torch.einsum("ijk,jih->kh", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)
# Bilinear form computed three ways and cross-checked:
#   out[i, j] = sum_k sum_l a[i, k] * b[j, k, l] * c[i, l]
# Index sizes: i = 2, k = 3, j = 5, l = 7.
a = torch.randn(2, 3)
b = torch.randn(5, 3, 7)
c = torch.randn(2, 7)
# Derive every size from the operands so the three computations below stay
# consistent if the shapes above are edited.
n_i, n_k = a.shape
n_j, _, n_l = b.shape

# 1) einsum, passing operands as varargs like the earlier examples.
torch_ein_out = torch.einsum('ik,jkl,il->ij', a, b, c).numpy()

# 2) nn.Bilinear computes y[i, j] = a[i] @ W[j] @ c[i] (no bias). Its weight
#    has shape (out_features, in1_features, in2_features) == b.shape, so b is
#    loaded into it.
m = torch.nn.Bilinear(n_k, n_l, n_j, bias=False)
with torch.no_grad():
    # Copy into the existing Parameter instead of rebinding `.data`: the
    # layer keeps its own storage and the autograd-bypassing `.data`
    # attribute is not needed.
    m.weight.copy_(b)
torch_org_out = m(a, c).detach().numpy()

# 3) Explicit loops as the reference definition of the contraction.
np_a = a.numpy()
np_b = b.numpy()
np_c = c.numpy()
np_out = np.empty((n_i, n_j), dtype=np.float32)
# Outer loops run over the free (output) indices i and j ...
for i in range(n_i):
    for j in range(n_j):
        # ... inner loops sum over the contracted indices k and l.
        sum_result = 0
        for k in range(n_k):
            for l in range(n_l):
                sum_result += np_a[i, k] * np_b[j, k, l] * np_c[i, l]
        np_out[i, j] = sum_result

print("torch ein out: \n", torch_ein_out)
print("torch org out: \n", torch_org_out)
print("numpy out: \n", np_out)
print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out))
print("is torch_org_out == torch_ein_out ?", np.allclose(torch_ein_out, torch_org_out))