# Source code for pywick.callbacks.CSVLogger

import csv
import os
from collections import OrderedDict
from collections.abc import Iterable

import torch

from . import Callback

__all__ = ['CSVLogger']

class CSVLogger(Callback):
    """
    Logs epoch-level metrics to a CSV file

    One row is written at the end of every epoch. The columns are ``epoch``
    followed by the sorted keys of the first ``logs`` dict received after
    training starts (minus the bookkeeping keys in ``_SKIPPED_KEYS``).

    :param file: (string) path to csv file
    :param separator: (string) delimiter for file
    :param append: (bool) whether to append result to existing file or make new file
    """

    # Entries of ``logs`` that are never written as columns.
    _SKIPPED_KEYS = frozenset({'num_batches', 'num_epoch'})

    def __init__(self, file, separator=',', append=False):
        self.file = file
        self.sep = separator
        self.append = append
        self.writer = None
        self.keys = None
        self.append_header = True
        # Defined up front so on_train_end is safe even if training never began.
        self.csv_file = None
        super().__init__()

    def _close_file(self):
        """Close the currently open CSV file handle, if there is one."""
        if self.csv_file is not None:
            self.csv_file.close()
            self.csv_file = None

    @staticmethod
    def _format_value(value):
        """
        Convert a single log value into something suitable for a CSV cell.

        :param value: the raw entry taken from ``logs``
        :return: tensors are converted to plain Python numbers / lists first;
            strings and bytes are returned unchanged; any other iterable is
            rendered as ``"[a, b, ...]"``; everything else is returned as-is
        """
        if isinstance(value, torch.Tensor):
            # tolist() yields a scalar for 0-dim tensors and (nested) lists
            # otherwise, so the cell holds numbers rather than 'tensor(...)'.
            value = value.tolist()
        if isinstance(value, (str, bytes)):
            # Strings are iterable, but must not be split into characters.
            return value
        if isinstance(value, Iterable):
            return '"[%s]"' % ', '.join(map(str, value))
        return value

    def on_train_begin(self, logs=None):
        """
        Open the target file for writing (or appending).

        When appending to an existing file, the header row is only written
        later if the first line of that file is empty.

        :param logs: unused
        """
        # Do not leak a handle left over from a previous, unfinished run.
        self._close_file()

        if self.append:
            if os.path.exists(self.file):
                with open(self.file, newline='') as existing:
                    self.append_header = not existing.readline()
            mode = 'a'
        else:
            mode = 'w'
        # newline='' is required by the csv module; without it, platforms that
        # translate '\n' produce an extra blank line after every row.
        self.csv_file = open(self.file, mode, newline='')

    def on_epoch_end(self, epoch, logs=None):
        """
        Append one row with the metrics of the epoch that just finished.

        The writer (and therefore the column layout) is created lazily on the
        first call after training starts.

        :param epoch: value written to the ``epoch`` column
        :param logs: (dict) metric name -> value; may be None
        """
        logs = logs or {}
        skipped = self._SKIPPED_KEYS

        if self.writer is None:
            self.keys = sorted(logs.keys())
            columns = ['epoch'] + [k for k in self.keys if k not in skipped]
            # Default 'excel' dialect with only the delimiter overridden.
            self.writer = csv.DictWriter(self.csv_file,
                                         fieldnames=columns,
                                         delimiter=self.sep)
            if self.append_header:
                self.writer.writeheader()

        row_dict = OrderedDict({'epoch': epoch})
        # A metric that is absent this epoch is left out of the row, so
        # DictWriter fills the cell with its default empty value instead of
        # the logger aborting training with a KeyError.
        row_dict.update(
            (key, self._format_value(logs[key]))
            for key in self.keys
            if key not in skipped and key in logs
        )
        self.writer.writerow(row_dict)
        self.csv_file.flush()

    def on_train_end(self, logs=None):
        """
        Close the file and drop the writer so a later run starts fresh.

        :param logs: unused
        """
        self._close_file()
        self.writer = None