commit 6791264987
parent 43e486aa1c

Update meaning dataset function.
@@ -63,12 +63,14 @@ class LitModule(pl.LightningModule):
         logits = logits.contiguous().view(-1, logits.size(-1))
         labels = batch["labels"][..., 1:]
         labels = labels.contiguous().view(-1)
-        label_mask = batch["mask"][..., 1:]
-        label_mask = label_mask.contiguous().view(-1)
-        logits_m = logits[label_mask]
-        labels_m = labels[label_mask]
-        self.metric_accuracy.update(logits_m, labels_m)
+        if batch["mask"] != None:
+            label_mask = batch["mask"][..., 1:]
+            label_mask = label_mask.contiguous().view(-1)
+            logits = logits[label_mask]
+            labels = labels[label_mask]
+        if logits.numel() != 0 and labels.numel() != 0:
+            self.metric_accuracy.update(logits, labels)
         self.metric_loss.update(loss)

     def on_validation_epoch_end(self) -> None:
         self.log("val_loss", self.metric_loss, rank_zero_only=True)
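The new validation logic applies the label mask only when the batch provides one, and guards the accuracy update against a batch whose mask filters out every position (the commit keeps a plain `!= None` comparison; `is not None` would be the idiomatic test). A minimal standalone sketch of the same flatten-shift-mask pattern, with illustrative shapes and names that are not taken from this repo:

    import torch
    import torchmetrics

    # Illustrative batch: 2 sequences of length 5 over a vocab of 8.
    logits = torch.randn(2, 5, 8)
    labels = torch.randint(0, 8, (2, 5))
    mask = torch.tensor([[1, 1, 0, 1, 1], [1, 0, 0, 1, 1]], dtype=torch.bool)

    # Shift for next-token prediction, then flatten, as in the diff.
    logits = logits[:, :-1, :].contiguous().view(-1, logits.size(-1))
    labels = labels[..., 1:].contiguous().view(-1)
    label_mask = mask[..., 1:].contiguous().view(-1)

    # Keep only unmasked positions.
    logits = logits[label_mask]
    labels = labels[label_mask]

    accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=8)
    if logits.numel() != 0 and labels.numel() != 0:  # all-masked batches contribute nothing
        accuracy.update(logits, labels)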
@@ -103,11 +103,6 @@ class MeaningMap:
                     for i, newm in enumerate(m_list)
                 ]
             )
-            # ml = ml[ma > 0]  # cut off the 0 token, such as [12,32,0,42,32]
-            # mr = mr[ma > 0]
-            # mrl = mrl[ma > 0]
-            # ma = ma[ma > 0]
-
             ms_data.append(ma)
             ms_level.append(ml)
             ms_rank_idx.append(mr)
@@ -199,20 +194,53 @@ class MeaningMap:
         return max(self.ms_len)

     def get_tree_str(tree, prefix):
-        if isinstance(tree, dict):
+        if isinstance(tree, list):
             base = ""
-            last_is_dict = None
-            for key, value in tree.items():
-                new_prefix = (len(str(key)) + 2) * " " + prefix
-                dict_string = MeaningMap.get_tree_str(value, new_prefix)
-                if dict_string:
-                    base += "\n" + prefix + str(key) + ": " + dict_string
-                    last_is_dict = True
-                else:
-                    base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " "
-                    last_is_dict = False
+            for t in tree:
+                base += MeaningMap.get_tree_str(t, "")
             return base
-        return None
+        else:
+            if isinstance(tree, dict):
+                base = ""
+                last_is_dict = None
+                for key, value in tree.items():
+                    new_prefix = (len(str(key)) + 2) * " " + prefix
+                    dict_string = MeaningMap.get_tree_str(value, new_prefix)
+                    if dict_string:
+                        base += "\n" + prefix + str(key) + ": " + dict_string
+                        last_is_dict = True
+                    else:
+                        base += "\n" + prefix + str(key) + " " if last_is_dict else str(key) + " "
+                        last_is_dict = False
+                return base
+            return None
+
+    def get_tree_indexed_str(tree, data, prefix):
+        if isinstance(tree, list):
+            base = ""
+            qlen = 0
+            for i, t in enumerate(tree):
+                s, l = MeaningMap.get_tree_indexed_str(t, data[i], "")
+                base += s
+                qlen += l
+            return (base, qlen)
+        else:
+            if isinstance(tree, dict):
+                base = ""
+                qlen = 0
+                last_is_dict = None
+                for key, value in tree.items():
+                    new_prefix = (len(str(key)) + 2) * " " + prefix
+                    dict_string, l = MeaningMap.get_tree_indexed_str(value, data[qlen:], new_prefix)
+                    if dict_string:
+                        base += "\n" + prefix + str(key) + ": " + dict_string
+                        last_is_dict = True
+                    else:
+                        base += "\n" + prefix + str(data[qlen]) + " " if last_is_dict else str(data[qlen]) + " "
+                        last_is_dict = False
+                    qlen += l
+                return (base, qlen)
+            return (None, 1)

     def token_frequency(tree, freq):
         if isinstance(tree, dict):
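`get_tree_str` now treats a list as a batch of trees and moves the old dict-walking logic into the non-list branch, while the new `get_tree_indexed_str` walks the same structure but prints one entry of `data` per leaf, returning the rendered string together with how many leaf slots it consumed (each non-dict leaf counts as 1, per `return (None, 1)`). A hedged usage sketch with an invented tree:

    import numpy as np

    # Invented nested meaning tree: dict keys are meanings, leaves are tokens.
    tree = {900: {41: {5: 5, 7: 7}, 77: 77}}  # three leaves: 5, 7, 77

    print(MeaningMap.get_tree_str([tree], ""))

    # The indexed variant substitutes data values for the leaves; here a
    # per-position mask takes the place of the tokens.
    mask = np.array([[1, 0, 1]])  # one row per tree, one entry per leaf
    s, consumed = MeaningMap.get_tree_indexed_str([tree], mask, "")
    print(s, consumed)  # consumed == 3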
@@ -280,22 +308,16 @@ class MeaningDataset(Dataset):
         return len(self.seq)

     def set_mask(self, level=None, idx=None):
+        if self.mask_level is not None and self.mask_idx is not None:
+            assert len(self.mask_level) > 0, "len must > 0"
+            assert len(self.mask_level) == len(self.mask_idx), "mask level and mask index must be same length"
+            assert isinstance(self.mask_level, list), "mask level must be list"
+            assert isinstance(self.mask_idx, list), "mask index must be list"
         self.mask_level = level
         self.mask_idx = idx

     def __getitem__(self, idx):
-        output = {}
-        data = torch.tensor(self.seq[idx]).long()
-        output["input_ids"] = data
-        output["labels"] = data.clone()
-        output["token_type_ids"] = torch.zeros(data.shape)
-        output["tree"] = self.tree[idx]
-        output["level"] = self.level[idx]
-        if self.mask_level is not None and self.mask_idx is not None:
-            output["mask"] = torch.tensor(self.get_seq_mask(idx, self.mask_level, self.mask_idx))
-        else:
-            output["mask"] = torch.ones(data.shape, dtype=torch.long)
-        return output
+        return self.get_batch([idx])

     def get_batch(self, idx_list):  # must equal sequence length
         data = [self.seq[i] for i in idx_list]
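`__getitem__` now funnels through `get_batch([idx])`, so single-item and batched access share one code path. Note that the asserts added to `set_mask` run against `self.mask_level`/`self.mask_idx` before the new values are assigned, i.e. they validate the previously stored mask; if the intent was to validate the incoming arguments, a sketch like the following (an assumption about intent, not the committed code) would do that:

    def set_mask(self, level=None, idx=None):
        # Validate the incoming arguments rather than the stored attributes.
        if level is not None and idx is not None:
            assert isinstance(level, list), "mask level must be list"
            assert isinstance(idx, list), "mask index must be list"
            assert len(level) > 0, "len must > 0"
            assert len(level) == len(idx), "mask level and mask index must be same length"
        self.mask_level = level
        self.mask_idx = idx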
@@ -306,12 +328,7 @@ class MeaningDataset(Dataset):
         output["token_type_ids"] = torch.zeros(data.shape)
         output["tree"] = [self.tree[i] for i in idx_list]
         output["level"] = [self.level[i] for i in idx_list]
-        if self.mask_level is not None and self.mask_idx is not None:
-            output["mask"] = torch.tensor(
-                np.stack([self.get_seq_mask(i, self.mask_level, self.mask_idx) for i in idx_list], axis=0)
-            )
-        else:
-            output["mask"] = torch.ones(data.shape, dtype=torch.long)
+        output["mask"] = self.get_seq_mask_tensor(idx_list)
         return output

     def get_token(self, idx):  # must equal sequence length
@@ -335,6 +352,8 @@ class MeaningDataset(Dataset):
         new.rank_idx = new.rank_idx[start:end]
         new.rank_all = new.rank_all[start:end]
         new.seq_meaning = new.seq_meaning[start:end]
+        new.mask_level = self.mask_level
+        new.mask_idx = self.mask_idx
         return new

     def split(self, ratio):
@@ -355,6 +374,19 @@ class MeaningDataset(Dataset):
         rank_all = (self.rank_all[idx] >> (4 * level)).astype(np.int32) & 0xF
         return rank_idx == (rank_all + index if index < 0 else index)

+    def get_seq_mask_tensor(self, idx_list):
+        if self.mask_level is not None and self.mask_idx is not None:
+            mask = torch.tensor(
+                np.stack([self.get_seq_mask(idx, self.mask_level[0], self.mask_idx[0]) for idx in idx_list], axis=0)
+            )
+            for i, l in enumerate(self.mask_level[1:]):
+                mask = mask & torch.tensor(
+                    np.stack([self.get_seq_mask(idx, l, self.mask_idx[i]) for idx in idx_list], axis=0)
+                )
+            return mask
+        else:
+            return None


 class BatchGroupMeaningDataloader(Dataset):
     def __init__(self, dataset: MeaningDataset, batch_size, shuffle=True, drop_last=True):
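`get_seq_mask` reads ranks that are bit-packed four bits per tree level: `(value >> (4 * level)) & 0xF` extracts the nibble for `level`, and a negative `index` counts back from that level's rank count. `get_seq_mask_tensor` then ANDs the per-level masks together; note that `enumerate(self.mask_level[1:])` restarts `i` at 0, so the second level is paired with `self.mask_idx[0]` rather than `self.mask_idx[1]`, which looks like it may have been meant as `self.mask_idx[i + 1]`. A worked example of the nibble packing, with invented values:

    import numpy as np

    # Pack rank 2 (level 0), rank 5 (level 1), rank 1 (level 2),
    # one 4-bit nibble per level.
    packed = (2 << 0) | (5 << 4) | (1 << 8)  # == 0x152

    for level in range(3):
        print(level, (packed >> (4 * level)) & 0xF)  # -> 2, 5, 1

    # Negative index: -1 selects the last rank at a level.
    rank_idx = np.array([0, 1, 2, 3])
    rank_all = np.array([4, 4, 4, 4])
    index = -1
    print(rank_idx == (rank_all + index if index < 0 else index))  # [False False False  True]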
@@ -414,7 +446,8 @@ class BatchGroupMeaningDataloader(Dataset):

 if __name__ == "__main__":

-    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=False)
+    md = MeaningDataset(100000, 115200, vocab_size=1024, size=1024, use_cache=True)
+    md.set_mask([1], [-1])
     train, val = md.split(0.95)
     fdaf = md.__getitem__(920)
     print(md.print_tree(920))
@@ -496,12 +529,18 @@ if __name__ == "__main__":
     ), "False"
     freq = md.token_frequency()

-    dl = BatchGroupMeaningDataloader(train, 2)
-    # length = len(dl)
-    # it = iter(dl)
-    # ne1 = next(it)
-    # ne2 = next(it)
-    # ne3 = next(it)
+    dl = BatchGroupMeaningDataloader(val, 1)
+    length = len(dl)
+    it = iter(dl)
+    ne1 = next(it)
+    tree = ne1["tree"]
+    mask = ne1["mask"].cpu().numpy()
+    t = MeaningMap.get_tree_str(tree, "")
+    print(t)
+    m, l = MeaningMap.get_tree_indexed_str(tree, mask, "")
+    print(m)
+    ne2 = next(it)
+    ne3 = next(it)

     # map1 = dl.get_tree(0)
     # map2 = dl.get_tree(1)
wit/train.py (11 changed lines)
@@ -34,11 +34,16 @@ hidden_size = 1024  # 128 1024 2048  32
 num_attention_heads = 16  # 8 8 16
 num_hidden_layers = 3  # 6 12 24  3

-mask_level = None
-mask_idx = None
+mask_level = [0, 1]
+mask_idx = [0, 0]
+
+# mask_level = [0, 1]
+# mask_idx = [0, -1]

 # name = "vocab_ratio_level_data_hidden_head_layer"
-name = "rank"
+# name = "mask_level_idx"
+name = "single_token"

 ver = f"{vocab_size}" + "_" + f"{level_ratio}" + "_" + f"{level}" + "_" + f"{dataset_level}"
 ver = ver + "_" + f"{hidden_size}" + "_" + f"{num_attention_heads}" + "_" + f"{num_hidden_layers}"
 ver = ver + "_" + f"{mask_level}" + "_" + f"{mask_idx}"
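With `mask_level = [0, 1]` and `mask_idx = [0, 0]`, training keeps only positions whose rank is 0 at level 0 and 0 at level 1; the commented `[0, -1]` variant would select the last rank at level 1 instead. The version string concatenates every knob, as this sketch shows (the `level_ratio`/`level`/`dataset_level` values are placeholders, not read from the diff):

    vocab_size, level_ratio, level, dataset_level = 1024, 2.0, 5, 3  # placeholder values
    hidden_size, num_attention_heads, num_hidden_layers = 1024, 16, 3
    mask_level, mask_idx = [0, 1], [0, 0]

    ver = f"{vocab_size}" + "_" + f"{level_ratio}" + "_" + f"{level}" + "_" + f"{dataset_level}"
    ver = ver + "_" + f"{hidden_size}" + "_" + f"{num_attention_heads}" + "_" + f"{num_hidden_layers}"
    ver = ver + "_" + f"{mask_level}" + "_" + f"{mask_idx}"
    print(ver)  # 1024_2.0_5_3_1024_16_3_[0, 1]_[0, 0]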