Witllm/wit/logger.py

26 lines
661 B
Python

from typing import Any, Mapping, Optional

from pytorch_lightning import loggers
from pytorch_lightning.utilities import rank_zero_only


class TBLogger(loggers.TensorBoardLogger):
    """TensorBoard logger that drops the ``"epoch"`` entry from logged metrics.

    Only ``log_metrics`` is overridden; all other behaviour is inherited from
    ``pytorch_lightning.loggers.TensorBoardLogger``.
    """

    @rank_zero_only
    def log_metrics(self, metrics: Mapping[str, Any], step: Optional[int] = None):
        """Forward ``metrics`` to the base logger without the ``"epoch"`` key.

        Args:
            metrics: Metric name -> value mapping. It is NOT modified: a
                filtered copy is forwarded, so any other logger that receives
                the same mapping still sees its original contents.
            step: Step to associate with the metrics. Defaults to ``None`` to
                match the base-class signature.

        Returns:
            Whatever the base-class ``log_metrics`` returns.
        """
        # Build a copy instead of popping in place: the caller may reuse this
        # mapping for other loggers, and mutating it would leak our filtering.
        filtered = {
            name: value for name, value in metrics.items() if name != "epoch"
        }
        return super().log_metrics(filtered, step)


class WBLogger(loggers.WandbLogger):
    """Weights & Biases logger that drops the ``"epoch"`` entry from logged metrics.

    Only ``log_metrics`` is overridden; all other behaviour is inherited from
    ``pytorch_lightning.loggers.WandbLogger``.
    """

    @rank_zero_only
    def log_metrics(self, metrics: Mapping[str, Any], step: Optional[int] = None):
        """Forward ``metrics`` to the base logger without the ``"epoch"`` key.

        Args:
            metrics: Metric name -> value mapping. It is NOT modified: a
                filtered copy is forwarded, so any other logger that receives
                the same mapping still sees its original contents.
            step: Step to associate with the metrics. Defaults to ``None`` to
                match the base-class signature.

        Returns:
            Whatever the base-class ``log_metrics`` returns.
        """
        # Build a copy instead of popping in place: the caller may reuse this
        # mapping for other loggers, and mutating it would leak our filtering.
        filtered = {
            name: value for name, value in metrics.items() if name != "epoch"
        }
        return super().log_metrics(filtered, step)


class MLFLogger(loggers.MLFlowLogger):
    """MLflow logger that drops the ``"epoch"`` entry from logged metrics.

    Only ``log_metrics`` is overridden; all other behaviour is inherited from
    ``pytorch_lightning.loggers.MLFlowLogger``.
    """

    @rank_zero_only
    def log_metrics(self, metrics: Mapping[str, Any], step: Optional[int] = None):
        """Forward ``metrics`` to the base logger without the ``"epoch"`` key.

        Args:
            metrics: Metric name -> value mapping. It is NOT modified: a
                filtered copy is forwarded, so any other logger that receives
                the same mapping still sees its original contents.
            step: Step to associate with the metrics. Defaults to ``None`` to
                match the base-class signature.

        Returns:
            Whatever the base-class ``log_metrics`` returns.
        """
        # Build a copy instead of popping in place: the caller may reuse this
        # mapping for other loggers, and mutating it would leak our filtering.
        filtered = {
            name: value for name, value in metrics.items() if name != "epoch"
        }
        return super().log_metrics(filtered, step)