Update meaning dataset function.

This commit is contained in:
Colin 2024-04-12 20:04:04 +08:00
parent 43e486aa1c
commit 6791264987
3 changed files with 97 additions and 51 deletions

View File

@ -63,11 +63,13 @@ class LitModule(pl.LightningModule):
logits = logits.contiguous().view(-1, logits.size(-1))
labels = batch["labels"][..., 1:]
labels = labels.contiguous().view(-1)
if batch["mask"] != None:
label_mask = batch["mask"][..., 1:]
label_mask = label_mask.contiguous().view(-1)
logits_m = logits[label_mask]
labels_m = labels[label_mask]
self.metric_accuracy.update(logits_m, labels_m)
logits = logits[label_mask]
labels = labels[label_mask]
if logits.numel() != 0 and labels.numel() != 0:
self.metric_accuracy.update(logits, labels)
self.metric_loss.update(loss)
def on_validation_epoch_end(self) -> None:

View File

@ -103,11 +103,6 @@ class MeaningMap:
for i, newm in enumerate(m_list)
]
)
# ml = ml[ma > 0] # cut off the 0 token, such as [12,32,0,42,32]
# mr = mr[ma > 0]
# mrl = mrl[ma > 0]
# ma = ma[ma > 0]
ms_data.append(ma)
ms_level.append(ml)
ms_rank_idx.append(mr)
@ -199,6 +194,12 @@ class MeaningMap:
return max(self.ms_len)
def get_tree_str(tree, prefix):
if isinstance(tree, list):
base = ""
for t in tree:
base += MeaningMap.get_tree_str(t, "")
return base
else:
if isinstance(tree, dict):
base = ""
last_is_dict = None
@ -214,6 +215,33 @@ class MeaningMap:
return base
return None
def get_tree_indexed_str(tree, data, prefix):
if isinstance(tree, list):
base = ""
qlen = 0
for i, t in enumerate(tree):
s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
base += s
qlen += l
return (base, qlen)
else:
if isinstance(tree, dict):
base = ""
qlen = 0
last_is_dict = None
for key, value in tree.items():
new_prefix = (len(str(key)) + 2) * " " + prefix
dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
if dict_string:
base += "\n" + prefix + str(key) + ": " + dict_string
last_is_dict = True
else:
base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
last_is_dict = False
qlen += l
return (base, qlen)
return (None, 1)
def token_frequency(tree, freq):
if isinstance(tree, dict):
for key, value in tree.items():
@ -280,22 +308,16 @@ class MeaningDataset(Dataset):
return len(self.seq)
def set_mask(self, level=None, idx=None):
if self.mask_level is not None and self.mask_idx is not None:
assert len(self.mask_level) > 0, "len must > 0"
assert len(self.mask_level) == len(self.mask_idx), "mask level and mask index must be same length"
assert isinstance(self.mask_level, list), "mask level must be list"
assert isinstance(self.mask_idx, list), "mask index must be list"
self.mask_level = level
self.mask_idx = idx
def __getitem__(self, idx):
output = {}
data = torch.tensor(self.seq[idx]).long()
output["input_ids"] = data
output["labels"] = data.clone()
output["token_type_ids"] = torch.zeros(data.shape)
output["tree"] = self.tree[idx]
output["level"] = self.level[idx]
if self.mask_level is not None and self.mask_idx is not None:
output["mask"] = torch.tensor(self.get_seq_mask(idx, self.mask_level, self.mask_idx))
else:
output["mask"] = torch.ones(data.shape, dtype=torch.long)
return output
return self.get_batch([idx])
def get_batch(self, idx_list): # must equal sequence length
data = [self.seq[i] for i in idx_list]
@ -306,12 +328,7 @@ class MeaningDataset(Dataset):
output["token_type_ids"] = torch.zeros(data.shape)
output["tree"] = [self.tree[i] for i in idx_list]
output["level"] = [self.level[i] for i in idx_list]
if self.mask_level is not None and self.mask_idx is not None:
output["mask"] = torch.tensor(
np.stack([self.get_seq_mask(i, self.mask_level, self.mask_idx) for i in idx_list], axis=0)
)
else:
output["mask"] = torch.ones(data.shape, dtype=torch.long)
output["mask"] = self.get_seq_mask_tensor(idx_list)
return output
def get_token(self, idx): # must equal sequence length
@ -335,6 +352,8 @@ class MeaningDataset(Dataset):
new.rank_idx = new.rank_idx[start:end]
new.rank_all = new.rank_all[start:end]
new.seq_meaning = new.seq_meaning[start:end]
new.mask_level = self.mask_level
new.mask_idx = self.mask_idx
return new
def split(self, ratio):
@ -355,6 +374,19 @@ class MeaningDataset(Dataset):
rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
return rank_idx == (rank_all + index if index < 0 else index)
def get_seq_mask_tensor(self, idx_list):
if self.mask_level is not None and self.mask_idx is not None:
mask = torch.tensor(
np.stack([self.get_seq_mask(idx, self.mask_level[0], self.mask_idx[0]) for idx in idx_list], axis=0)
)
for i, l in enumerate(self.mask_level[1:]):
mask = mask & torch.tensor(
np.stack([self.get_seq_mask(idx, l, self.mask_idx[i]) for idx in idx_list], axis=0)
)
return mask
else:
return None
class BatchGroupMeaningDataloader(Dataset):
def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
@ -414,7 +446,8 @@ class BatchGroupMeaningDataloader(Dataset):
if __name__ == "__main__":
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True)
md.set_mask([1], [-1])
train, val = md.split(0.95)
fdaf = md.__getitem__(920)
print(md.print_tree(920))
@ -496,12 +529,18 @@ if __name__ == "__main__":
), "False"
freq = md.token_frequency()
dl = BatchGroupMeaningDataloader(train, 2)
# length = len(dl)
# it = iter(dl)
# ne1 = next(it)
# ne2 = next(it)
# ne3 = next(it)
dl = BatchGroupMeaningDataloader(val, 1)
length = len(dl)
it = iter(dl)
ne1 = next(it)
tree = ne1["tree"]
mask = ne1["mask"].cpu().numpy()
t = MeaningMap.get_tree_str(tree, "")
print(t)
m, l = MeaningMap.get_tree_indexed_str(tree, mask, "")
print(m)
ne2 = next(it)
ne3 = next(it)
# map1 = dl.get_tree(0)
# map2 = dl.get_tree(1)

View File

@ -34,11 +34,16 @@ hidden_size = 1024 # 128 1024 2048 32
num_attention_heads = 16 # 8 8 16
num_hidden_layers = 3 # 6 12 24 3
mask_level = None
mask_idx = None
mask_level = [0, 1]
mask_idx = [0, 0]
# mask_level = [0, 1]
# mask_idx = [0, -1]
# name = "vocab_ratio_level_data_hidden_head_layer"
name = "rank"
# name = "mask_level_idx"
name = "single_token"
ver = f"{vocab_size}" + "_" + f"{level_ratio}" + "_" + f"{level}" + "_" + f"{dataset_level}"
ver = ver + "_" + f"{hidden_size}" + "_" + f"{num_attention_heads}" + "_" + f"{num_hidden_layers}"
ver = ver + "_" + f"{mask_level}" + "_" + f"{mask_idx}"