Source code for pywick.callbacks.TQDM

from tqdm import tqdm
from . import Callback

# Public API of this module: only the TQDM callback is exported.
__all__ = ['TQDM']


class TQDM(Callback):
    """
    Callback that displays progress bar and useful statistics in terminal
    """

    def __init__(self):
        """
        TQDM Progress Bar callback

        This callback is automatically applied to every SuperModule
        if verbose > 0
        """
        # Progress bar for the epoch currently running; None when no bar
        # could be created (or before the first epoch starts).
        self.progbar = None
        # Logs dict handed to on_train_begin; read for 'num_batches' and
        # 'num_epoch' when a new bar is created.
        self.train_logs = None
        super(TQDM, self).__init__()

    def __enter__(self):
        """Return the callback itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the progress bar (if any) when leaving the ``with`` block."""
        # make sure the progress bar gets closed
        if self.progbar is not None:
            self.progbar.close()

    def _build_postfix(self, logs, include_other=False):
        """
        Assemble the dict shown to the right of the progress bar.

        :param logs: dict of values reported by the trainer for this hook;
            ``None`` is treated as empty.
        :param include_other: when True, keys that do not end in 'metric'
            are copied through unformatted; when False they are dropped.
        :return: dict suitable for ``tqdm.set_postfix``
        """
        # NOTE(review): self.trainer is not set in this class; presumably the
        # Callback base / trainer attaches it before hooks run -- confirm.
        history = self.trainer.history
        postfix = {key: '%.04f' % value
                   for key, value in history.batch_metrics.items()}
        for key, value in (logs or {}).items():
            if key.endswith('metric'):
                # Strip the '_metric' suffix for a shorter display label.
                postfix[key.split('_metric')[0]] = '%.02f' % value
            elif include_other:
                postfix[key] = value
        postfix['learn_rates'] = history.lrs
        return postfix

    def on_train_begin(self, logs=None):
        """Remember the training logs so each epoch can size its bar."""
        self.train_logs = logs

    def on_epoch_begin(self, epoch, logs=None):
        """
        Create a fresh progress bar for the epoch.

        :param epoch: zero-based epoch index (displayed as ``epoch + 1``)
        :param logs: unused
        """
        # Drop any bar left over from a previous epoch so the other hooks
        # never write to a stale, already-closed bar if creation fails below.
        self.progbar = None
        try:
            self.progbar = tqdm(total=self.train_logs['num_batches'],
                                unit=' batches')
            self.progbar.set_description(
                'Epoch %i/%i' % (epoch + 1, self.train_logs['num_epoch']))
        except Exception:
            # The progress display is best-effort: missing/invalid train logs
            # must not abort training. Unlike a bare ``except`` this still
            # lets KeyboardInterrupt and SystemExit propagate.
            pass

    def on_epoch_end(self, epoch, logs=None):
        """
        Show the final statistics for the epoch and close the bar.

        :param epoch: unused
        :param logs: dict of end-of-epoch values; keys ending in 'metric'
            are formatted to two decimals, all other keys are shown as-is
        """
        if self.progbar is None:
            return
        self.progbar.set_postfix(self._build_postfix(logs, include_other=True))
        # NOTE(review): on_batch_begin already advances the bar once per
        # batch, so this extra update() may push the count past `total`;
        # kept for backward-compatible output -- confirm intent.
        self.progbar.update()
        self.progbar.close()

    def on_batch_begin(self, batch, logs=None):
        """Advance the progress bar by one batch."""
        if self.progbar is None:
            return
        self.progbar.update(1)

    def on_batch_end(self, batch, logs=None):
        """
        Refresh the statistics displayed next to the bar.

        :param batch: unused
        :param logs: dict of per-batch values; only keys ending in 'metric'
            are displayed
        """
        if self.progbar is None:
            return
        self.progbar.set_postfix(self._build_postfix(logs))