Witllm/test/embedding.py

import torch
import torch.nn as nn
# Define the vocabulary size and embedding dimension
vocab_size = 10000
embedding_dim = 16
# Create an Embedding layer
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
# Define an input tensor of token indices with shape (batch_size, sequence_length)
input_tensor = torch.tensor([[1, 2], [4, 3]], dtype=torch.long)
# Pass the input tensor through the Embedding layer
embedded_tensor = embedding(input_tensor)
print("embedded weight shape:")
print(embedding.weight.shape)
print("embedded weight:")
print(embedding.weight)
# The output shape is (batch_size, sequence_length, embedding_dim)
print("embedded out shape:")
print(embedded_tensor.shape)
print("embedded out:")
print(embedded_tensor)
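
# Sanity check (a minimal sketch added for illustration, not part of the
# original script): an nn.Embedding forward pass is equivalent to indexing
# rows of its weight matrix with the token indices, so the two results
# should match exactly.
lookup = embedding.weight[input_tensor]
assert torch.equal(embedded_tensor, lookup)
print("embedding(input) matches weight[input]:", torch.equal(embedded_tensor, lookup))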