Update meaning dataset function.
This commit is contained in:
parent
43e486aa1c
commit
6791264987
|
@ -63,12 +63,14 @@ class LitModule(pl.LightningModule):
|
|||
logits = logits.contiguous().view(-1, logits.size(-1))
|
||||
labels = batch["labels"][..., 1:]
|
||||
labels = labels.contiguous().view(-1)
|
||||
label_mask = batch["mask"][..., 1:]
|
||||
label_mask = label_mask.contiguous().view(-1)
|
||||
logits_m = logits[label_mask]
|
||||
labels_m = labels[label_mask]
|
||||
self.metric_accuracy.update(logits_m, labels_m)
|
||||
self.metric_loss.update(loss)
|
||||
if batch["mask"] != None:
|
||||
label_mask = batch["mask"][..., 1:]
|
||||
label_mask = label_mask.contiguous().view(-1)
|
||||
logits = logits[label_mask]
|
||||
labels = labels[label_mask]
|
||||
if logits.numel() != 0 and labels.numel() != 0:
|
||||
self.metric_accuracy.update(logits, labels)
|
||||
self.metric_loss.update(loss)
|
||||
|
||||
def on_validation_epoch_end(self) -> None:
|
||||
self.log("val_loss", self.metric_loss, rank_zero_only=True)
|
||||
|
|
|
@ -103,11 +103,6 @@ class MeaningMap:
|
|||
for i, newm in enumerate(m_list)
|
||||
]
|
||||
)
|
||||
# ml = ml[ma > 0] # cut off the 0 token, such as [12,32,0,42,32]
|
||||
# mr = mr[ma > 0]
|
||||
# mrl = mrl[ma > 0]
|
||||
# ma = ma[ma > 0]
|
||||
|
||||
ms_data.append(ma)
|
||||
ms_level.append(ml)
|
||||
ms_rank_idx.append(mr)
|
||||
|
@ -199,20 +194,53 @@ class MeaningMap:
|
|||
return max(self.ms_len)
|
||||
|
||||
def get_tree_str(tree, prefix):
|
||||
if isinstance(tree, dict):
|
||||
if isinstance(tree, list):
|
||||
base = ""
|
||||
last_is_dict = None
|
||||
for key, value in tree.items():
|
||||
new_prefix = (len(str(key)) + 2) * " " + prefix
|
||||
dict_string = MeaningMap.get_tree_str(value, new_prefix)
|
||||
if dict_string:
|
||||
base += "\n" + prefix + str(key) + ": " + dict_string
|
||||
last_is_dict = True
|
||||
else:
|
||||
base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " "
|
||||
last_is_dict = False
|
||||
for t in tree:
|
||||
base += MeaningMap.get_tree_str(t, "")
|
||||
return base
|
||||
return None
|
||||
else:
|
||||
if isinstance(tree, dict):
|
||||
base = ""
|
||||
last_is_dict = None
|
||||
for key, value in tree.items():
|
||||
new_prefix = (len(str(key)) + 2) * " " + prefix
|
||||
dict_string = MeaningMap.get_tree_str(value, new_prefix)
|
||||
if dict_string:
|
||||
base += "\n" + prefix + str(key) + ": " + dict_string
|
||||
last_is_dict = True
|
||||
else:
|
||||
base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " "
|
||||
last_is_dict = False
|
||||
return base
|
||||
return None
|
||||
|
||||
def get_tree_indexed_str(tree, data, prefix):
|
||||
if isinstance(tree, list):
|
||||
base = ""
|
||||
qlen = 0
|
||||
for i, t in enumerate(tree):
|
||||
s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
|
||||
base += s
|
||||
qlen += l
|
||||
return (base, qlen)
|
||||
else:
|
||||
if isinstance(tree, dict):
|
||||
base = ""
|
||||
qlen = 0
|
||||
last_is_dict = None
|
||||
for key, value in tree.items():
|
||||
new_prefix = (len(str(key)) + 2) * " " + prefix
|
||||
dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
|
||||
if dict_string:
|
||||
base += "\n" + prefix + str(key) + ": " + dict_string
|
||||
last_is_dict = True
|
||||
else:
|
||||
base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
|
||||
last_is_dict = False
|
||||
qlen += l
|
||||
return (base, qlen)
|
||||
return (None, 1)
|
||||
|
||||
def token_frequency(tree, freq):
|
||||
if isinstance(tree, dict):
|
||||
|
@ -280,22 +308,16 @@ class MeaningDataset(Dataset):
|
|||
return len(self.seq)
|
||||
|
||||
def set_mask(self, level=None, idx=None):
|
||||
if self.mask_level is not None and self.mask_idx is not None:
|
||||
assert len(self.mask_level) > 0, "len must > 0"
|
||||
assert len(self.mask_level) == len(self.mask_idx), "mask level and mask index must be same length"
|
||||
assert isinstance(self.mask_level, list), "mask level must be list"
|
||||
assert isinstance(self.mask_idx, list), "mask index must be list"
|
||||
self.mask_level = level
|
||||
self.mask_idx = idx
|
||||
|
||||
def __getitem__(self, idx):
|
||||
output = {}
|
||||
data = torch.tensor(self.seq[idx]).long()
|
||||
output["input_ids"] = data
|
||||
output["labels"] = data.clone()
|
||||
output["token_type_ids"] = torch.zeros(data.shape)
|
||||
output["tree"] = self.tree[idx]
|
||||
output["level"] = self.level[idx]
|
||||
if self.mask_level is not None and self.mask_idx is not None:
|
||||
output["mask"] = torch.tensor(self.get_seq_mask(idx, self.mask_level, self.mask_idx))
|
||||
else:
|
||||
output["mask"] = torch.ones(data.shape, dtype=torch.long)
|
||||
return output
|
||||
return self.get_batch([idx])
|
||||
|
||||
def get_batch(self, idx_list): # must equal sequence length
|
||||
data = [self.seq[i] for i in idx_list]
|
||||
|
@ -306,12 +328,7 @@ class MeaningDataset(Dataset):
|
|||
output["token_type_ids"] = torch.zeros(data.shape)
|
||||
output["tree"] = [self.tree[i] for i in idx_list]
|
||||
output["level"] = [self.level[i] for i in idx_list]
|
||||
if self.mask_level is not None and self.mask_idx is not None:
|
||||
output["mask"] = torch.tensor(
|
||||
np.stack([self.get_seq_mask(i, self.mask_level, self.mask_idx) for i in idx_list], axis=0)
|
||||
)
|
||||
else:
|
||||
output["mask"] = torch.ones(data.shape, dtype=torch.long)
|
||||
output["mask"] = self.get_seq_mask_tensor(idx_list)
|
||||
return output
|
||||
|
||||
def get_token(self, idx): # must equal sequence length
|
||||
|
@ -335,6 +352,8 @@ class MeaningDataset(Dataset):
|
|||
new.rank_idx = new.rank_idx[start:end]
|
||||
new.rank_all = new.rank_all[start:end]
|
||||
new.seq_meaning = new.seq_meaning[start:end]
|
||||
new.mask_level = self.mask_level
|
||||
new.mask_idx = self.mask_idx
|
||||
return new
|
||||
|
||||
def split(self, ratio):
|
||||
|
@ -355,6 +374,19 @@ class MeaningDataset(Dataset):
|
|||
rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
|
||||
return rank_idx == (rank_all + index if index < 0 else index)
|
||||
|
||||
def get_seq_mask_tensor(self, idx_list):
|
||||
if self.mask_level is not None and self.mask_idx is not None:
|
||||
mask = torch.tensor(
|
||||
np.stack([self.get_seq_mask(idx, self.mask_level[0], self.mask_idx[0]) for idx in idx_list], axis=0)
|
||||
)
|
||||
for i, l in enumerate(self.mask_level[1:]):
|
||||
mask = mask & torch.tensor(
|
||||
np.stack([self.get_seq_mask(idx, l, self.mask_idx[i]) for idx in idx_list], axis=0)
|
||||
)
|
||||
return mask
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class BatchGroupMeaningDataloader(Dataset):
|
||||
def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
|
||||
|
@ -414,7 +446,8 @@ class BatchGroupMeaningDataloader(Dataset):
|
|||
|
||||
if __name__ == "__main__":
|
||||
|
||||
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
|
||||
md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True)
|
||||
md.set_mask([1], [-1])
|
||||
train, val = md.split(0.95)
|
||||
fdaf = md.__getitem__(920)
|
||||
print(md.print_tree(920))
|
||||
|
@ -496,12 +529,18 @@ if __name__ == "__main__":
|
|||
), "False"
|
||||
freq = md.token_frequency()
|
||||
|
||||
dl = BatchGroupMeaningDataloader(train, 2)
|
||||
# length = len(dl)
|
||||
# it = iter(dl)
|
||||
# ne1 = next(it)
|
||||
# ne2 = next(it)
|
||||
# ne3 = next(it)
|
||||
dl = BatchGroupMeaningDataloader(val, 1)
|
||||
length = len(dl)
|
||||
it = iter(dl)
|
||||
ne1 = next(it)
|
||||
tree = ne1["tree"]
|
||||
mask = ne1["mask"].cpu().numpy()
|
||||
t = MeaningMap.get_tree_str(tree, "")
|
||||
print(t)
|
||||
m, l = MeaningMap.get_tree_indexed_str(tree, mask, "")
|
||||
print(m)
|
||||
ne2 = next(it)
|
||||
ne3 = next(it)
|
||||
|
||||
# map1 = dl.get_tree(0)
|
||||
# map2 = dl.get_tree(1)
|
||||
|
|
11
wit/train.py
11
wit/train.py
|
@ -34,11 +34,16 @@ hidden_size = 1024 # 128 1024 2048 32
|
|||
num_attention_heads = 16 # 8 8 16
|
||||
num_hidden_layers = 3 # 6 12 24 3
|
||||
|
||||
mask_level = None
|
||||
mask_idx = None
|
||||
mask_level = [0, 1]
|
||||
mask_idx = [0, 0]
|
||||
|
||||
# mask_level = [0, 1]
|
||||
# mask_idx = [0, -1]
|
||||
|
||||
# name = "vocab_ratio_level_data_hidden_head_layer"
|
||||
name = "rank"
|
||||
# name = "mask_level_idx"
|
||||
name = "single_token"
|
||||
|
||||
ver = f"{vocab_size}" + "_" + f"{level_ratio}" + "_" + f"{level}" + "_" + f"{dataset_level}"
|
||||
ver = ver + "_" + f"{hidden_size}" + "_" + f"{num_attention_heads}" + "_" + f"{num_hidden_layers}"
|
||||
ver = ver + "_" + f"{mask_level}" + "_" + f"{mask_idx}"
|
||||
|
|
Loading…
Reference in New Issue