Refine code.
This commit is contained in:
parent
601c7f6510
commit
f1394d5974
|
@ -59,11 +59,11 @@ class LitModule(pl.LightningModule):
|
|||
labels = batch["labels"][..., 1:]
|
||||
labels = labels.contiguous().view(-1)
|
||||
label_mask = labels < self.vocab_size
|
||||
logits = logits[label_mask]
|
||||
labels = labels[label_mask]
|
||||
logits_m = logits[label_mask]
|
||||
labels_m = labels[label_mask]
|
||||
# m = torch.max(logits, 1).indices.cpu().numpy()
|
||||
# ll = labels.cpu().numpy()
|
||||
self.metric_accuracy.update(logits, labels)
|
||||
self.metric_accuracy.update(logits_m, labels_m)
|
||||
self.metric_loss.update(loss)
|
||||
|
||||
def on_validation_epoch_end(self) -> None:
|
||||
|
@ -71,18 +71,6 @@ class LitModule(pl.LightningModule):
|
|||
self.log("accuracy", self.metric_accuracy, rank_zero_only=True)
|
||||
|
||||
def configure_optimizers(self):
|
||||
strategy = self.trainer.strategy
|
||||
if isinstance(strategy, pl.strategies.DeepSpeedStrategy):
|
||||
assert "optimizer" not in strategy.config
|
||||
zero_config = strategy.config.get("zero_optimization")
|
||||
if zero_config is not None:
|
||||
if "offload_optimizer" in zero_config:
|
||||
import deepspeed
|
||||
|
||||
optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam(
|
||||
self.trainer.model.parameters(), lr=self.learning_rate
|
||||
)
|
||||
return optimizer
|
||||
optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
|
||||
return optimizer
|
||||
|
||||
|
|
13
wit/train.py
13
wit/train.py
|
@ -35,19 +35,26 @@ vocab_size = 4096
|
|||
|
||||
|
||||
class SpecialDataset(Dataset):
|
||||
def __init__(self, start=1, end=320, size=32768): # 1048576
|
||||
def __init__(self, start=1, end=16, size=32768): # 1048576 32768
|
||||
self.size = size
|
||||
self.features = []
|
||||
a = torch.randint(start, end, [size])
|
||||
b = torch.randint(start, end, [size])
|
||||
c = torch.randint(start, end, [size])
|
||||
d = torch.randint(start, end, [size])
|
||||
# self.data = torch.stack([a, b, a + b, a + b]).permute(1, 0)
|
||||
self.data = torch.stack([a, a, a + a]).permute(1, 0)
|
||||
z = torch.zeros([size]).long()
|
||||
# self.data = torch.stack([a, b, a + b, a + b, a + b * 2]).permute(1, 0)
|
||||
# self.data = torch.stack([a, b, a, a + b / 4]).permute(1, 0).long()
|
||||
# self.data = torch.stack([a, a + 1, a + 2]).permute(1, 0).long()
|
||||
self.data = torch.stack([a, b, a]).permute(1, 0).long()
|
||||
# self.data = torch.stack([a, b, a, a + a / 8, a + a / 4, a + a / 2, a + a]).permute(1, 0).long()
|
||||
|
||||
# a = torch.randint(start, end, [size])
|
||||
# self.data = torch.stack([a, a, a + a]).permute(1, 0) # accuracy=0.5
|
||||
# self.data = torch.stack([a, a + a, a]).permute(1, 0) # accuracy=1
|
||||
# 只能有一种算法,而且第一个值不能用于训练
|
||||
# 太陡峭的过度导致难以拟合
|
||||
# 搜索空间太大,难以拟合
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
|
Loading…
Reference in New Issue