Update stride test.
This commit is contained in:
parent
ce81f26845
commit
59e079f5e7
|
@ -17,4 +17,9 @@
|
|||
1. key[10] = 1000.0
|
||||
2. 每一行数据(像素)表示一个新的token,和前面所有token的关系
|
||||
|
||||

|
||||

|
||||
|
||||
## 在样本的中间插入固定的token
|
||||
|
||||
1. 使用stride的方法,在每个token的中间插入一个固定的无用的token
|
||||
2. 插入的token用或者不用于计算loss,对精度都没有提升
|
|
@ -433,7 +433,7 @@ class MeaningDataset(Dataset):
|
|||
val_mask, stride_mask = self.get_seq_mask_tensor(idx_list)
|
||||
output["val_mask"] = val_mask
|
||||
labels = data.clone()
|
||||
labels[~stride_mask] = self.vocab_size # set to vocab_size will be masked in label
|
||||
# labels[~stride_mask] = self.vocab_size # set to vocab_size will be masked in label
|
||||
output["labels"] = labels
|
||||
output["meaning"] = [self.seq_meaning[i] for i in idx_list]
|
||||
return output
|
||||
|
|
Loading…
Reference in New Issue