"""Minimal demonstration of ``torch.nn.Embedding``.

Builds an embedding table, looks up a small batch of token indices, and
prints the shape and contents of both the weight matrix and the lookup
result.
"""
import torch
import torch.nn as nn

# Vocabulary size (rows of the table) and width of each embedding vector.
vocab_size = 10000
embedding_dim = 16

# Embedding layer: a learnable lookup table of shape
# (vocab_size, embedding_dim).
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

# Input tensor of token indices, shape (batch_size, sequence_length).
# Indices must be an integer tensor, so the dtype is stated explicitly
# instead of going through the legacy ``torch.LongTensor`` constructor.
input_tensor = torch.tensor([[1, 2], [4, 3]], dtype=torch.long)

# Pass the indices through the embedding layer; each index is replaced by
# its row of the weight matrix.
embedded_tensor = embedding(input_tensor)

print("embedded weight shape:")
print(embedding.weight.shape)
print("embedded weight:")
print(embedding.weight)

# The output has shape (batch_size, sequence_length, embedding_dim).
print("embedded out shape:")
print(embedded_tensor.shape)
print("embedded out:")
print(embedded_tensor)