Refine code.
This commit is contained in:
parent 601c7f6510
commit f1394d5974
@@ -59,11 +59,11 @@ class LitModule(pl.LightningModule):
         labels = batch["labels"][..., 1:]
         labels = labels.contiguous().view(-1)
         label_mask = labels < self.vocab_size
-        logits = logits[label_mask]
-        labels = labels[label_mask]
+        logits_m = logits[label_mask]
+        labels_m = labels[label_mask]
         # m = torch.max(logits, 1).indices.cpu().numpy()
         # ll = labels.cpu().numpy()
-        self.metric_accuracy.update(logits, labels)
+        self.metric_accuracy.update(logits_m, labels_m)
         self.metric_loss.update(loss)

     def on_validation_epoch_end(self) -> None:
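The substantive fix in this hunk is the rename: the old code overwrote logits and labels with their masked views before updating the metric, so the unmasked tensors were lost to any later code. A minimal standalone sketch of the masking step, assuming torchmetrics' multiclass Accuracy and tensors already flattened to [N, vocab_size] and [N]:

# Sketch only; vocab_size and shapes are illustrative.
import torch
import torchmetrics

vocab_size = 4096
metric_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=vocab_size)

logits = torch.randn(8, vocab_size)              # flattened [N, vocab_size]
labels = torch.randint(0, 2 * vocab_size, (8,))  # some ids fall outside the vocab

label_mask = labels < vocab_size                 # drop out-of-vocab / special ids
logits_m = logits[label_mask]                    # masked views under new names,
labels_m = labels[label_mask]                    # leaving the originals intact

metric_accuracy.update(logits_m, labels_m)
print(metric_accuracy.compute())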
@@ -71,18 +71,6 @@ class LitModule(pl.LightningModule):
         self.log("accuracy", self.metric_accuracy, rank_zero_only=True)

     def configure_optimizers(self):
-        strategy = self.trainer.strategy
-        if isinstance(strategy, pl.strategies.DeepSpeedStrategy):
-            assert "optimizer" not in strategy.config
-            zero_config = strategy.config.get("zero_optimization")
-            if zero_config is not None:
-                if "offload_optimizer" in zero_config:
-                    import deepspeed
-
-                    optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam(
-                        self.trainer.model.parameters(), lr=self.learning_rate
-                    )
-                    return optimizer
         optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
         return optimizer

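For context on the deleted branch: "offload_optimizer" is a key inside DeepSpeed's zero_optimization config, and with the optimizer step offloaded to the CPU the old code swapped in deepspeed.ops.adam.DeepSpeedCPUAdam. A sketch of what such a config looks like, assuming Lightning's DeepSpeedStrategy; the values are illustrative, not taken from this repo:

# Illustrative DeepSpeed config carrying the "offload_optimizer" key
# the removed branch probed for.
from pytorch_lightning.strategies import DeepSpeedStrategy

ds_config = {
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
    },
}
strategy = DeepSpeedStrategy(config=ds_config)
# After this commit, configure_optimizers returns torch.optim.AdamW
# unconditionally, whatever the active strategy is.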
wit/train.py (13 lines changed)
@@ -35,19 +35,26 @@ vocab_size = 4096


 class SpecialDataset(Dataset):
-    def __init__(self, start=1, end=320, size=32768):  # 1048576
+    def __init__(self, start=1, end=16, size=32768):  # 1048576 32768
         self.size = size
         self.features = []
         a = torch.randint(start, end, [size])
         b = torch.randint(start, end, [size])
         c = torch.randint(start, end, [size])
         d = torch.randint(start, end, [size])
-        # self.data = torch.stack([a, b, a + b, a + b]).permute(1, 0)
-        self.data = torch.stack([a, a, a + a]).permute(1, 0)
+        z = torch.zeros([size]).long()
+        # self.data = torch.stack([a, b, a + b, a + b, a + b * 2]).permute(1, 0)
+        # self.data = torch.stack([a, b, a, a + b / 4]).permute(1, 0).long()
+        # self.data = torch.stack([a, a + 1, a + 2]).permute(1, 0).long()
+        self.data = torch.stack([a, b, a]).permute(1, 0).long()
+        # self.data = torch.stack([a, b, a, a + a / 8, a + a / 4, a + a / 2, a + a]).permute(1, 0).long()

         # a = torch.randint(start, end, [size])
         # self.data = torch.stack([a, a, a + a]).permute(1, 0)  # accuracy=0.5
         # self.data = torch.stack([a, a + a, a]).permute(1, 0)  # accuracy=1
+        # Only one rule can be present, and the first value cannot be used for training
+        # Transitions that are too steep are hard to fit
+        # A search space that is too large is hard to fit

     def __len__(self):
         return self.size
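The translated comments explain the dataset design: under next-token prediction (the module shifts with labels[..., 1:]), a row of the form [a, b, a] carries exactly one learnable step, since b is independent noise and only the final copy of a is deterministic. A toy sketch of one row follows; how the row splits into inputs and labels is an assumption based on that shift, as the diff does not show __getitem__:

# One row of the active pattern self.data = stack([a, b, a]), toy values.
import torch

a = torch.tensor([3])
b = torch.tensor([7])
row = torch.stack([a, b, a]).permute(1, 0).long()[0]  # tensor([3, 7, 3])

inputs, labels = row[:-1], row[1:]
# step 0: predict b=7 from a=3    -> b is random noise, unlearnable
# step 1: predict a=3 from (a, b) -> pure copy task, learnable
print(inputs, labels)  # tensor([3, 7]) tensor([7, 3])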