"""Lightning loggers that omit the ``"epoch"`` entry when logging metrics."""
from typing import Mapping, Optional

from pytorch_lightning import loggers
from pytorch_lightning.utilities import rank_zero_only


class _DropEpochMixin:
    """Strip the ``"epoch"`` key from metrics before delegating to a logger.

    List this mixin *before* the concrete Lightning logger in the base-class
    list so that ``super().log_metrics`` resolves to that logger's method.
    """

    @rank_zero_only
    def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None):
        """Forward every metric except ``"epoch"`` to the parent logger.

        Args:
            metrics: Metric name -> value mapping. It is NOT modified; a
                filtered copy is forwarded, so other loggers handed the same
                dict by the trainer still see ``"epoch"``.
            step: Step index passed through unchanged to the parent logger.

        Returns:
            Whatever the parent logger's ``log_metrics`` returns.
        """
        # NOTE(review): presumably "epoch" is dropped so the backend does not
        # record it as its own metric series — confirm the intent.
        filtered = {name: value for name, value in metrics.items() if name != "epoch"}
        return super().log_metrics(filtered, step)


class TBLogger(_DropEpochMixin, loggers.TensorBoardLogger):
    """TensorBoard logger that skips the ``"epoch"`` metric."""


class WBLogger(_DropEpochMixin, loggers.WandbLogger):
    """Weights & Biases logger that skips the ``"epoch"`` metric."""


class MLFLogger(_DropEpochMixin, loggers.MLFlowLogger):
    """MLflow logger that skips the ``"epoch"`` metric."""