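# Witllm/test/tensor.py
# (file-viewer metadata: 59 lines, 1.2 KiB, Python)
# Scratch script printing the results of basic torch tensor operations.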

import torch
import torch.nn.functional as F
# --- Tensor.tile -----------------------------------------------------------
# tile(dims) repeats the tensor along each dimension.
# If dims has fewer entries than the tensor has dimensions, it is left-padded
# with 1s. If it has more, the tensor is treated as having extra leading
# size-1 dimensions.
# NOTE: ``(2)`` is just the int 2, not a tuple; a one-element tuple is
# written ``(2,)``. The label printed below reflects the actual argument.
x = torch.tensor([[1, 2], [3, 4]])
print(x)
for reps in ((2,), (2, 1), (2, 1, 2), (2, 1, 1)):
    tiled = x.tile(reps)
    # f"{reps}" renders the tuple exactly as written, e.g. "(2, 1)".
    print(f"x.tile({reps}) -> ", tiled.shape)
    print(tiled)
    print()
# --- element-wise comparison, products, unsqueeze/squeeze ------------------
# Uses the integer tensor ``x`` defined earlier in this script.
y = torch.tensor([[2, 1], [3, 4]])
print(torch.ne(y, x))  # element-wise "not equal"
print()
for axis in (1, 0):
    print(torch.prod(x, axis))  # product along dim 1, then along dim 0
print()
expanded = x.unsqueeze(1)  # inserts a size-1 dimension at position 1
print(expanded.shape)
print(expanded.squeeze(1).shape)  # removing that size-1 dim again
# --- mean along each dimension ---------------------------------------------
# Python ``float`` maps to torch.float64, so build the tensor in that dtype
# directly (mean() needs a floating-point input).
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
print(torch.mean(x, dim=1))
print(torch.mean(x, dim=0))
print(torch.mean(x, dim=0, keepdim=True))  # reduced dim kept with size 1
for _ in range(2):
    print()
# --- flatten / stack / cat -------------------------------------------------
# (A second, identical re-assignment of x that used to follow the flatten
# call was redundant and has been removed.)
x = torch.tensor([[1, 2], [3, 4]])
print(x.flatten(0))  # collapses dims 0..end into one dimension
# stack inserts a NEW dimension; cat joins along an EXISTING one.
# So if A and B are of shape (3, 4):
#   torch.cat([A, B], dim=0) will be of shape (6, 4)
#   torch.stack([A, B], dim=0) will be of shape (2, 3, 4)
# Here x is (2, 2): stack(dim=1) -> (2, 2, 2), cat(dim=1) -> (2, 4).
print(torch.stack((x, x), 1))
print(torch.cat((x, x), 1))
# --- batched matmul shape check --------------------------------------------
# The 4-D shapes look like (batch, heads, seq_len, head_dim); presumably this
# mirrors an attention-score computation (NOTE(review): confirm intent).
batch, heads, seq_len, head_dim = 1, 32, 6, 128
x = torch.ones([batch, heads, seq_len, head_dim])
y = torch.ones([batch, heads, head_dim, seq_len])
z = torch.ones([batch, heads, seq_len, head_dim])
# matmul multiplies the last two dims and broadcasts over the leading ones:
#   (..., 6, 128) @ (..., 128, 6) -> (..., 6, 6)
#   (..., 6, 6)   @ (..., 6, 128) -> (..., 6, 128)
att = x @ y
mm = att @ z
print(mm.shape)