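# PyTorch scratch experiments: LayerNorm, a ZeroPad2d "time shift", matrix
# inverse round-trip, Tensor.tile / ne / prod / unsqueeze, mean, flatten /
# stack / cat, and an attention-shaped matmul chain.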
import torch
|
|
import torch.nn.functional as F
|
|
import torch.nn as nn
|
|
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
# Assume a batch of sequence data with shape (batch_size, seq_len, hidden_dim).
batch_size, seq_len, hidden_dim = 2, 2, 4
input_tensor = torch.randn(batch_size, seq_len, hidden_dim)

# Two independent LayerNorm layers, each normalizing over the last axis (size hidden_dim).
layer_norm1, layer_norm2 = [nn.LayerNorm(hidden_dim) for _ in range(2)]

# Apply LayerNorm once, then run the already-normalized result through the second layer.
normalized_once = layer_norm1(input_tensor)
normalized_twice = layer_norm2(normalized_once)

print(input_tensor.numpy())
for stage in (normalized_once, normalized_twice):
    # "\n" * 3 plus print's own trailing newline gives four newlines as a separator.
    print("\n" * 3)
    print(stage.detach().numpy())

output = normalized_twice
|
|
|
|
# ZeroPad2d padding is (left, right, top, bottom) over the last two axes:
#   - last axis: one zero column on each side -> 768 + 2 = 770
#   - second-to-last axis: top=-1 drops the first row and bottom=+1 appends a
#     zero row, so the length stays 7 but every row moves up by one position.
# NOTE(review): RWKV-style time shifting usually pads (0, 0, 1, -1); confirm
# these paddings are intended.
time_shift = nn.ZeroPad2d((1, 1, -1, 1))

# NOTE(review): torch.empty leaves the values uninitialized; only xx's shape and
# its zero-padded border are meaningful here.
x1 = torch.empty((1, 7, 768), dtype=float)
xx = time_shift(x1)
|
|
|
|
|
|
x1 = torch.tensor([[1, 2]], dtype=float)
x2 = torch.tensor([[5, 6], [7, 8]], dtype=float)

# Forward product, then undo it by right-multiplying with x2^{-1}:
# (x1 @ x2) @ x2^{-1} == x1, so "y_inverse" is really the recovered x1.
y = torch.matmul(x1, x2)
x_inverse = torch.inverse(x2)
recovered = y @ x_inverse

# Swap the two axes: the recovered (1, 2) row becomes a (2, 1) column.
y_inverse = recovered.transpose(0, 1)
|
|
|
|
|
|
x = torch.tensor([[1, 2], [3, 4]], dtype=float)
print(x)

# Tensor.tile(reps): if reps has fewer entries than x.ndim, 1s are prepended to
# reps; if it has more, x is treated as having extra leading size-1 axes.
# Note: "(2)" is just the int 2, not a one-element tuple (that would be "(2,)").
tile_cases = [
    ("(2)", 2),
    ("(2, 1)", (2, 1)),
    ("(2, 1, 2)", (2, 1, 2)),
    ("(2, 1, 1)", (2, 1, 1)),
]
for label, reps in tile_cases:
    tiled = x.tile(reps)
    print(f"x.tile({label}) -> ", tiled.shape)
    print(tiled)
    print()

# Element-wise "not equal"; the int64 y is compared against the float64 x.
y = torch.tensor([[2, 1], [3, 4]])
print(y.ne(x))
print()

# Product along each row (dim 1), then along each column (dim 0).
print(x.prod(1))
print(x.prod(0))
print()

# unsqueeze(1) inserts a size-1 axis at position 1; squeeze(1) removes it again.
expanded = x.unsqueeze(1)
print(expanded.shape)
print(expanded.squeeze(1).shape)
|
|
|
|
|
|
# Same 2x2 values, stored as float64 so mean() is defined.
x = torch.tensor([[1, 2], [3, 4]], dtype=float)

# mean(1): average of each row; mean(0): average of each column;
# keepdim=True keeps the reduced axis, giving shape (1, 2) instead of (2,).
for dim, keep in ((1, False), (0, False), (0, True)):
    print(x.mean(dim, keepdim=keep))

# Two empty lines as a separator (print appends its own trailing newline).
print("\n")
|
|
x = torch.tensor([[1, 2], [3, 4]])

# Flatten every axis from 0 onward into a single 1-D tensor.
print(x.flatten(0))

# stack inserts a NEW axis at the given position; cat joins along an EXISTING one.
# For two (2, 2) inputs along dim 1: stack -> (2, 2, 2), cat -> (2, 4).
pair = (x, x)
print(torch.stack(pair, 1))
print(torch.cat(pair, 1))
# So if A and B are of shape (3, 4):
# torch.cat([A, B], dim=0) will be of shape (6, 4)
# torch.stack([A, B], dim=0) will be of shape (2, 3, 4)
|
|
|
|
# Shape check for a chained batched matmul over the last two axes.
# Presumably mimics attention: x ~ Q, y ~ K already transposed, z ~ V — confirm.
x = torch.ones([1, 32, 6, 128])
y = torch.ones([1, 32, 128, 6])
z = torch.ones([1, 32, 6, 128])

att = x @ y    # (1, 32, 6, 128) @ (1, 32, 128, 6) -> (1, 32, 6, 6)
mm = att @ z   # (1, 32, 6, 6) @ (1, 32, 6, 128) -> (1, 32, 6, 128)
print(mm.shape)
|