Update stride test.

This commit is contained in:
Colin 2025-08-13 16:37:24 +08:00
parent ce81f26845
commit 59e079f5e7
2 changed files with 7 additions and 2 deletions

View File

@ -17,4 +17,9 @@
1. key[10] = 1000.0
2. 每一行数据像素表示一个新的token和前面所有token的关系
![alt text](q@k_seq_47_layer_0.png)
![alt text](q@k_seq_47_layer_0.png)
## 在样本的中间插入固定的token
1. 使用stride的方法在每个token的中间插入一个固定的无用的token
2. 插入的token用或者不用于计算loss对精度都没有提升

View File

@ -433,7 +433,7 @@ class MeaningDataset(Dataset):
val_mask, stride_mask = self.get_seq_mask_tensor(idx_list)
output["val_mask"] = val_mask
labels = data.clone()
labels[~stride_mask] = self.vocab_size # set to vocab_size will be masked in label
# labels[~stride_mask] = self.vocab_size # set to vocab_size will be masked in label
output["labels"] = labels
output["meaning"] = [self.seq_meaning[i] for i in idx_list]
return output