Refine code.
parent 601c7f6510
commit f1394d5974
@@ -59,11 +59,11 @@ class LitModule(pl.LightningModule):
         labels = batch["labels"][..., 1:]
         labels = labels.contiguous().view(-1)
         label_mask = labels < self.vocab_size
-        logits = logits[label_mask]
-        labels = labels[label_mask]
+        logits_m = logits[label_mask]
+        labels_m = labels[label_mask]
         # m = torch.max(logits, 1).indices.cpu().numpy()
         # ll = labels.cpu().numpy()
-        self.metric_accuracy.update(logits, labels)
+        self.metric_accuracy.update(logits_m, labels_m)
         self.metric_loss.update(loss)
 
     def on_validation_epoch_end(self) -> None:
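Renaming the masked tensors to logits_m/labels_m keeps the unmasked logits and labels available and matches the updated metric_accuracy.update(logits_m, labels_m) call. A minimal standalone sketch of the same masking flow; the shapes, the 2 * vocab_size label range, and the print are illustrative, not from this commit:

import torch

vocab_size = 4096
logits = torch.randn(10, vocab_size)              # one score row per shifted position
labels = torch.randint(0, 2 * vocab_size, (10,))  # some ids fall outside the vocab

# Keep only positions whose label is a real vocabulary id; ids at or above
# vocab_size (e.g. padding/special tokens) are excluded from the metrics.
label_mask = labels < vocab_size
logits_m = logits[label_mask]  # [n_valid, vocab_size]
labels_m = labels[label_mask]  # [n_valid]
print(logits_m.shape, labels_m.shape)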
@@ -71,18 +71,6 @@ class LitModule(pl.LightningModule):
         self.log("accuracy", self.metric_accuracy, rank_zero_only=True)
 
     def configure_optimizers(self):
-        strategy = self.trainer.strategy
-        if isinstance(strategy, pl.strategies.DeepSpeedStrategy):
-            assert "optimizer" not in strategy.config
-            zero_config = strategy.config.get("zero_optimization")
-            if zero_config is not None:
-                if "offload_optimizer" in zero_config:
-                    import deepspeed
-
-                    optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam(
-                        self.trainer.model.parameters(), lr=self.learning_rate
-                    )
-                    return optimizer
         optimizer = torch.optim.AdamW(self.trainer.model.parameters(), lr=self.learning_rate)
         return optimizer
 
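The removed branch constructed deepspeed.ops.adam.DeepSpeedCPUAdam whenever ZeRO optimizer offload appeared in the strategy config; configure_optimizers now always returns a plain AdamW. If CPU offload is wanted again, one alternative is to declare the optimizer inside the DeepSpeed config itself so the strategy builds an offload-compatible Adam on its own. A sketch assuming Lightning's DeepSpeedStrategy dict config; the stage and lr values are illustrative, not from this repo:

import pytorch_lightning as pl

strategy = pl.strategies.DeepSpeedStrategy(
    config={
        "zero_optimization": {
            "stage": 2,
            "offload_optimizer": {"device": "cpu"},
        },
        # With an "optimizer" block in the config, DeepSpeed instantiates its
        # own CPU-capable Adam, so the LightningModule can keep returning a
        # plain AdamW for every other strategy.
        "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
    }
)
trainer = pl.Trainer(strategy=strategy)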
wit/train.py | 13 ++++++++++---
@@ -35,19 +35,26 @@ vocab_size = 4096
 
 
 class SpecialDataset(Dataset):
-    def __init__(self, start=1, end=320, size=32768):  # 1048576
+    def __init__(self, start=1, end=16, size=32768):  # 1048576 32768
         self.size = size
         self.features = []
         a = torch.randint(start, end, [size])
         b = torch.randint(start, end, [size])
         c = torch.randint(start, end, [size])
         d = torch.randint(start, end, [size])
-        # self.data = torch.stack([a, b, a + b, a + b]).permute(1, 0)
-        self.data = torch.stack([a, a, a + a]).permute(1, 0)
+        z = torch.zeros([size]).long()
+        # self.data = torch.stack([a, b, a + b, a + b, a + b * 2]).permute(1, 0)
+        # self.data = torch.stack([a, b, a, a + b / 4]).permute(1, 0).long()
+        # self.data = torch.stack([a, a + 1, a + 2]).permute(1, 0).long()
+        self.data = torch.stack([a, b, a]).permute(1, 0).long()
+        # self.data = torch.stack([a, b, a, a + a / 8, a + a / 4, a + a / 2, a + a]).permute(1, 0).long()
 
         # a = torch.randint(start, end, [size])
         # self.data = torch.stack([a, a, a + a]).permute(1, 0)  # accuracy=0.5
         # self.data = torch.stack([a, a + a, a]).permute(1, 0)  # accuracy=1
+        # Only one mapping rule can be present, and the first value cannot be used for training
+        # A transition that is too steep is hard to fit
+        # A search space that is too large is hard to fit
 
     def __len__(self):
         return self.size
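Each retained row is [a, b, a]: under next-token prediction, b must be guessed from a alone (chance level, roughly 1/15 with end=16), while the final a is copied from position 0 and is fully learnable, which is what the translated comment about the first value not being usable for training points at. Shrinking end from 320 to 16 mainly shrinks the search space, in line with the last comment. The hunk shows only __init__ and __len__; below is a self-contained sketch of how such a dataset is typically consumed, where the __getitem__ body and the "input_ids"/"labels" keys are assumptions, not code from this commit:

import torch
from torch.utils.data import Dataset, DataLoader

class SpecialDatasetSketch(Dataset):
    """Illustrative stand-in mirroring the hunk above (the real __getitem__ is outside it)."""

    def __init__(self, start=1, end=16, size=32768):
        self.size = size
        a = torch.randint(start, end, [size])
        b = torch.randint(start, end, [size])
        # Rows of the form [a, b, a]: only the last position is predictable.
        self.data = torch.stack([a, b, a]).permute(1, 0).long()

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Assumed format: the LitModule shifts these ids by one to build
        # targets, matching batch["labels"][..., 1:] in the first hunk.
        return {"input_ids": self.data[idx], "labels": self.data[idx]}

loader = DataLoader(SpecialDatasetSketch(), batch_size=4)
batch = next(iter(loader))
print(batch["input_ids"].shape)  # torch.Size([4, 3])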
				