# Listing metadata from the original page: 68 lines, 1.7 KiB, Python.
import numpy as np
import torch

# Example 5: vector inner product.
A = torch.randn(10)
B = torch.randn(10)
# "i,i->" multiplies the two vectors along the shared index i and sums it
# away, leaving a 0-d tensor. Equivalent to torch.dot(A, B).
C = torch.einsum("i,i->", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)

# Example 6: vector outer product.
A = torch.randn(10)
B = torch.randn(5)
# "i,j->ij" keeps both indices, so no summation happens and the result has
# one row per element of A and one column per element of B.
# Equivalent to torch.outer(A, B).
C = torch.einsum("i,j->ij", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)

# Example 7: matrix multiplication.
A = torch.randn(5, 4)
B = torch.randn(4, 6)
# "ik,kj->ij" sums over the shared index k (columns of A, rows of B).
# Equivalent to torch.matmul(A, B).
C = torch.einsum("ik,kj->ij", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)

# Example 8: tensor contraction.
A = torch.randn(3, 4, 5)
B = torch.randn(4, 3, 6)
# "ijk,jih->kh" sums over i (dim 0 of A, dim 1 of B) and j (dim 1 of A,
# dim 0 of B), leaving the free indices k and h.
# Equivalent to torch.tensordot(A, B, dims=[(0, 1), (1, 0)]).
C = torch.einsum("ijk,jih->kh", A, B)
print("before:", A.shape, B.shape)
print("after:", C.shape)

# Cross-check a three-operand einsum against torch.nn.Bilinear and against an
# explicit NumPy loop over the same indices.
a = torch.randn(2, 3)
b = torch.randn(5, 3, 7)
c = torch.randn(2, 7)

# Index sizes: i = 2, k = 3, j = 5, l = 7.
torch_ein_out = torch.einsum("ik,jkl,il->ij", a, b, c).numpy()

# Bilinear(3, 7, 5) holds a weight of shape (5, 3, 7), the same shape as b.
m = torch.nn.Bilinear(3, 7, 5, bias=False)
# Overwrite the weight in place, outside autograd tracking, rather than
# rebinding `.data`.
with torch.no_grad():
    m.weight.copy_(b)
torch_org_out = m(a, c).detach().numpy()

np_a = a.numpy()
np_b = b.numpy()
np_c = c.numpy()

# Take the loop bounds from the arrays instead of repeating the literals.
n_i, n_k = np_a.shape
n_j, _, n_l = np_b.shape
np_out = np.empty((n_i, n_j), dtype=np.float32)

# Outer loops run over the free indices (i and j) ...
for i in range(n_i):
    for j in range(n_j):
        # ... inner loops run over the summation indices (k and l).
        sum_result = 0.0
        for k in range(n_k):
            for l in range(n_l):
                sum_result += np_a[i, k] * np_b[j, k, l] * np_c[i, l]
        np_out[i, j] = sum_result

print("torch ein out: \n", torch_ein_out)
print("torch org out: \n", torch_org_out)
print("numpy out: \n", np_out)
# The three results are float32 sums taken in different orders, so allow a
# small absolute tolerance; the default atol=1e-8 can spuriously fail for
# entries that are close to zero.
print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out, atol=1e-5))
print("is torch_org_out == torch_ein_out ?", np.allclose(torch_ein_out, torch_org_out, atol=1e-5))