Witllm/test/einsum.py

68 lines
1.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import numpy as np
# Example 5: vector inner (dot) product.
# The index "i" appears in both operands and not in the output, so it is
# summed over; the result is a 0-d tensor. Equivalent to torch.dot(A, B).
A, B = torch.randn(10), torch.randn(10)
C = torch.einsum("i,i->", A, B)
print(f"before: {A.shape} {B.shape}")
print(f"after: {C.shape}")
# Example 6: vector outer product.
# Both indices survive into the output, so nothing is summed:
# C[i, j] = A[i] * B[j]. Equivalent to torch.outer(A, B).
A, B = torch.randn(10), torch.randn(5)
C = torch.einsum("i,j->ij", A, B)
print(f"before: {A.shape} {B.shape}")
print(f"after: {C.shape}")
# Example 7: matrix multiplication.
# The shared index "k" is contracted: C[i, j] = sum_k A[i, k] * B[k, j].
# Equivalent to torch.matmul(A, B).
A, B = torch.randn(5, 4), torch.randn(4, 6)
C = torch.einsum("ik,kj->ij", A, B)
print(f"before: {A.shape} {B.shape}")
print(f"after: {C.shape}")
# Example 8: tensor contraction.
# Indices "i" and "j" are contracted (note they appear in swapped order in
# the second operand): C[k, h] = sum_{i, j} A[i, j, k] * B[j, i, h].
# Equivalent to torch.tensordot(A, B, dims=[(0, 1), (1, 0)]).
A, B = torch.randn(3, 4, 5), torch.randn(4, 3, 6)
C = torch.einsum("ijk,jih->kh", A, B)
print(f"before: {A.shape} {B.shape}")
print(f"after: {C.shape}")
# Example: the same bilinear form computed three ways -- torch.einsum,
# torch.nn.Bilinear, and explicit NumPy loops -- then compared.
# Index sizes: i = 2, k = 3, j = 5, l = 7.
a = torch.randn(2, 3)
b = torch.randn(5, 3, 7)
c = torch.randn(2, 7)

# out[i, j] = sum_{k, l} a[i, k] * b[j, k, l] * c[i, l]
torch_ein_out = torch.einsum("ik,jkl,il->ij", a, b, c).numpy()

# Reference via nn.Bilinear(in1=3, in2=7, out=5); its weight has the same
# shape as b. Copy the values under no_grad instead of rebinding
# `weight.data`, so the parameter keeps its own storage and autograd state.
m = torch.nn.Bilinear(3, 7, 5, bias=False)
with torch.no_grad():
    m.weight.copy_(b)
torch_org_out = m(a, c).detach().numpy()

np_a = a.numpy()
np_b = b.numpy()
np_c = c.numpy()

# Take the loop bounds from the arrays rather than hard-coding them.
n_i, n_k = np_a.shape
n_j, n_l = np_b.shape[0], np_b.shape[2]
np_out = np.empty((n_i, n_j), dtype=np.float32)

# Outer loops run over the free indices (i and j, which appear in the output).
for i in range(n_i):
    for j in range(n_j):
        # Inner loops run over the summation indices (k and l).
        sum_result = 0.0
        for k in range(n_k):
            for l in range(n_l):
                sum_result += np_a[i, k] * np_b[j, k, l] * np_c[i, l]
        np_out[i, j] = sum_result

# print("matrix a:\n", np_a)
# print("matrix b:\n", np_b)
print("torch ein out: \n", torch_ein_out)
print("torch org out: \n", torch_org_out)
print("numpy out: \n", np_out)
# float32 results summed in different orders differ by rounding; an absolute
# tolerance keeps entries that happen to be near zero from failing the
# purely relative default check.
print("is np_out == torch_ein_out ?", np.allclose(torch_ein_out, np_out, atol=1e-5))
print("is torch_org_out == torch_ein_out ?", np.allclose(torch_ein_out, torch_org_out, atol=1e-5))