From f1394d5974f0af164b7935c98fda7e089704a6e0 Mon Sep 17 00:00:00 2001
From: Colin
Date: Fri, 8 Mar 2024 20:46:42 +0800
Subject: [PATCH] Refine code.

---
 wit/lit_module.py | 18 +++---------------
 wit/train.py      | 13 ++++++++++---
 2 files changed, 13 insertions(+), 18 deletions(-)

diff --git a/wit/lit_module.py b/wit/lit_module.py
index 1015be7..10e1fab 100644
--- a/wit/lit_module.py
+++ b/wit/lit_module.py
@@ -59,11 +59,11 @@
         labels = batch["labels"][..., 1:]
         labels = labels.contiguous().view(-1)
         label_mask = labels < self.vocab_size
-        logits = logits[label_mask]
-        labels = labels[label_mask]
+        logits_m = logits[label_mask]
+        labels_m = labels[label_mask]
         # m = torch.max(logits, 1).indices.cpu().numpy()
         # ll = labels.cpu().numpy()
-        self.metric_accuracy.update(logits, labels)
+        self.metric_accuracy.update(logits_m, labels_m)
         self.metric_loss.update(loss)
 
     def on_validation_epoch_end(self) -> None:
@@ -71,18 +71,6 @@
         self.log("accuracy", self.metric_accuracy, rank_zero_only=True)
 
     def configure_optimizers(self):
-        strategy = self.trainer.strategy
-        if isinstance(strategy, pl.strategies.DeepSpeedStrategy):
-            assert "optimizer" not in strategy.config
-            zero_config = strategy.config.get("zero_optimization")
-            if zero_config is not None:
-                if "offload_optimizer" in zero_config:
-                    import deepspeed
-
-                    optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam(
-                        self.trainer.model.parameters(), lr=self.learning_rate
-                    )
-                    return optimizer
         optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
         return optimizer
 
diff --git a/wit/train.py b/wit/train.py
index f6062e1..fc93eac 100644
--- a/wit/train.py
+++ b/wit/train.py
@@ -35,19 +35,26 @@
 vocab_size = 4096
 
 
 class SpecialDataset(Dataset):
-    def __init__(self, start=1, end=320, size=32768):  # 1048576
+    def __init__(self, start=1, end=16, size=32768):  # 1048576 32768
         self.size = size
         self.features = []
         a = torch.randint(start, end, [size])
         b = torch.randint(start, end, [size])
         c = torch.randint(start, end, [size])
         d = torch.randint(start, end, [size])
-        # self.data = torch.stack([a, b, a + b, a + b]).permute(1, 0)
-        self.data = torch.stack([a, a, a + a]).permute(1, 0)
+        z = torch.zeros([size]).long()
+        # self.data = torch.stack([a, b, a + b, a + b, a + b * 2]).permute(1, 0)
+        # self.data = torch.stack([a, b, a, a + b / 4]).permute(1, 0).long()
+        # self.data = torch.stack([a, a + 1, a + 2]).permute(1, 0).long()
+        self.data = torch.stack([a, b, a]).permute(1, 0).long()
+        # self.data = torch.stack([a, b, a, a + a / 8, a + a / 4, a + a / 2, a + a]).permute(1, 0).long()
         # a = torch.randint(start, end, [size])
         # self.data = torch.stack([a, a, a + a]).permute(1, 0)  # accuracy=0.5
         # self.data = torch.stack([a, a + a, a]).permute(1, 0)  # accuracy=1
+        # Only one algorithm can be present, and the first value cannot be used for training
+        # Too steep a transition makes it hard to fit
+        # A search space that is too large is hard to fit
 
     def __len__(self):
         return self.size
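
For context, a minimal sketch of the masked-accuracy update from the first hunk, assuming the metric is a torchmetrics multiclass Accuracy (its construction is not shown in the patch). The renamed logits_m/labels_m leave the unmasked tensors intact, which matters because the commented-out debugging lines still reference the originals:

import torch
import torchmetrics

vocab_size = 4096  # matches the constant in wit/train.py
metric_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=vocab_size)

# Flattened next-token logits and labels, as in the validation step.
logits = torch.randn(8, vocab_size)
labels = torch.randint(0, vocab_size + 8, (8,))  # some ids may be padding/special tokens

label_mask = labels < vocab_size  # keep only positions whose label is in-vocabulary
logits_m = logits[label_mask]     # new names: the originals stay available for debugging
labels_m = labels[label_mask]

metric_accuracy.update(logits_m, labels_m)
print(metric_accuracy.compute())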
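
The second hunk removes the DeepSpeed branch, so configure_optimizers now returns AdamW unconditionally. If ZeRO optimizer offload is still wanted, one alternative (a sketch, not taken from this repository) is to declare the optimizer in the DeepSpeed strategy config instead, which the removed assert previously forbade:

from pytorch_lightning.strategies import DeepSpeedStrategy

# Hypothetical configuration: with the Python-side DeepSpeedCPUAdam branch gone,
# optimizer offload would have to come from the DeepSpeed config itself.
strategy = DeepSpeedStrategy(
    config={
        "zero_optimization": {"stage": 2, "offload_optimizer": {"device": "cpu"}},
        "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    }
)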
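
The train.py hunk narrows the token range to [1, 16) and switches the synthetic task to [a, b, a] triples, i.e. the model must copy the first token at the last position; per the translated comments, mixing several rules or widening the range made the task hard to fit. A runnable sketch of the dataset, assuming a __getitem__ that simply returns the stored row (not shown in the patch):

import torch
from torch.utils.data import Dataset

class SpecialDataset(Dataset):
    def __init__(self, start=1, end=16, size=32768):
        self.size = size
        a = torch.randint(start, end, [size])
        b = torch.randint(start, end, [size])
        # One fixed rule per dataset: each sample is [a, b, a], so the model
        # must learn to echo position 0 at position 2; position 0 itself has
        # no preceding context and cannot be supervised.
        self.data = torch.stack([a, b, a]).permute(1, 0).long()

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Hypothetical accessor: the patch does not show __getitem__.
        return self.data[idx]

ds = SpecialDataset()
print(ds[0])  # e.g. tensor([ 3, 11,  3])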