Refine code.

This commit is contained in:
Colin 2024-03-08 20:46:42 +08:00
parent 601c7f6510
commit f1394d5974
2 changed files with 13 additions and 18 deletions

View File

@ -59,11 +59,11 @@ class LitModule(pl.LightningModule):
labels = batch["labels"][..., 1:]
labels = labels.contiguous().view(-1)
label_mask = labels < self.vocab_size
logits = logits[label_mask]
labels = labels[label_mask]
logits_m = logits[label_mask]
labels_m = labels[label_mask]
# m = torch.max(logits, 1).indices.cpu().numpy()
# ll = labels.cpu().numpy()
self.metric_accuracy.update(logits, labels)
self.metric_accuracy.update(logits_m, labels_m)
self.metric_loss.update(loss)
def on_validation_epoch_end(self) -> None:
@ -71,18 +71,6 @@ class LitModule(pl.LightningModule):
self.log("accuracy", self.metric_accuracy, rank_zero_only=True)
def configure_optimizers(self):
strategy = self.trainer.strategy
if isinstance(strategy, pl.strategies.DeepSpeedStrategy):
assert "optimizer" not in strategy.config
zero_config = strategy.config.get("zero_optimization")
if zero_config is not None:
if "offload_optimizer" in zero_config:
import deepspeed
optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam(
self.trainer.model.parameters(), lr=self.learning_rate
)
return optimizer
optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
return optimizer

View File

@ -35,19 +35,26 @@ vocab_size = 4096
class SpecialDataset(Dataset):
def __init__(self, start=1, end=320, size=32768): # 1048576
def __init__(self, start=1, end=16, size=32768): # 1048576 32768
self.size = size
self.features = []
a = torch.randint(start, end, [size])
b = torch.randint(start, end, [size])
c = torch.randint(start, end, [size])
d = torch.randint(start, end, [size])
# self.data = torch.stack([a, b, a + b, a + b]).permute(1, 0)
self.data = torch.stack([a, a, a + a]).permute(1, 0)
z = torch.zeros([size]).long()
# self.data = torch.stack([a, b, a + b, a + b, a + b * 2]).permute(1, 0)
# self.data = torch.stack([a, b, a, a + b / 4]).permute(1, 0).long()
# self.data = torch.stack([a, a + 1, a + 2]).permute(1, 0).long()
self.data = torch.stack([a, b, a]).permute(1, 0).long()
# self.data = torch.stack([a, b, a, a + a / 8, a + a / 4, a + a / 2, a + a]).permute(1, 0).long()
# a = torch.randint(start, end, [size])
# self.data = torch.stack([a, a, a + a]).permute(1, 0) # accuracy=0.5
# self.data = torch.stack([a, a + a, a]).permute(1, 0) # accuracy=1
# 只能有一种算法,而且第一个值不能用于训练
# 太陡峭的过度导致难以拟合
# 搜索空间太大,难以拟合
def __len__(self):
return self.size