Witllm/tensor.py

37 lines
672 B
Python
Raw Normal View History

2023-12-22 18:57:16 +08:00
import torch
# Scratch experiments with torch tensor ops on a small 2x2 integer tensor.
x = torch.tensor([[1, 2], [3, 4]])
print(x)


def _show_tile(label, reps):
    """Print *label* with the shape of ``x.tile(reps)``, then the tiled tensor and a blank line.

    Reads the module-level ``x`` at call time.
    """
    tiled = x.tile(reps)
    print(label, tiled.shape)
    print(tiled)
    print()


# NOTE: the "(2)" in the first label is just the int 2 (not a 1-tuple),
# so that call is x.tile(2).
_show_tile("x.tile((2)) -> ", 2)
_show_tile("x.tile((2, 1)) -> ", (2, 1))
_show_tile("x.tile((2, 1, 2)) -> ", (2, 1, 2))
_show_tile("x.tile((2, 1, 1)) -> ", (2, 1, 1))

y = torch.tensor([[2, 1], [3, 4]])
# Element-wise "not equal" comparison of y against x.
print(y.ne(x))
print()
# Products along dim 1, then along dim 0, one per line.
print(x.prod(1), x.prod(0), sep="\n")
print()
# Shape after inserting a size-1 axis at dim 1, then after removing it again.
print(x.unsqueeze(1).shape, x.unsqueeze(1).squeeze(1).shape, sep="\n")
2023-12-22 20:01:09 +08:00
# Rebind x as a float64 copy of the same 2x2 values (mean() needs a floating dtype).
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float64)
# Means along dim 1, along dim 0, and along dim 0 keeping the reduced axis; one per line.
print(x.mean(1), x.mean(0), x.mean(0, keepdim=True), sep="\n")