Source code for pywick.callbacks.ExperimentLogger

import csv
import os
import shutil
from collections import OrderedDict
from tempfile import NamedTemporaryFile

from . import Callback

__all__ = ['ExperimentLogger']

[docs]class ExperimentLogger(Callback): """ Generic logger callback for dumping experiment data. Can be extended for more utility. """ def __init__(self, directory, filename='Experiment_Logger.csv', save_prefix='Model_', separator=',', append=True): self.directory = directory self.filename = filename self.file = os.path.join(self.directory, self.filename) self.save_prefix = save_prefix self.sep = separator self.append = append self.writer = None self.keys = None self.append_header = True super(ExperimentLogger, self).__init__() def on_train_begin(self, logs=None): if self.append: open_type = 'a' else: open_type = 'w' # if append is True, find whether the file already has header num_lines = 0 if self.append: if os.path.exists(self.file): with open(self.file) as f: for num_lines, l in enumerate(f): pass # if header exists, DONT append header again with open(self.file) as f: self.append_header = not bool(len(f.readline())) model_idx = num_lines REJECT_KEYS = {'has_validation_data'} MODEL_NAME = self.save_prefix + str(model_idx) # figure out how to get model name self.row_dict = OrderedDict({'model': MODEL_NAME}) self.keys = sorted(logs.keys()) for k in self.keys: if k not in REJECT_KEYS: self.row_dict[k] = logs[k] class CustomDialect(csv.excel): delimiter = self.sep with open(self.file, open_type) as csv_file: writer = csv.DictWriter(csv_file, fieldnames=['model'] + [k for k in self.keys if k not in REJECT_KEYS], dialect=CustomDialect) if self.append_header: writer.writeheader() writer.writerow(self.row_dict) csv_file.flush() def on_train_end(self, logs=None): REJECT_KEYS = {'has_validation_data'} row_dict = self.row_dict class CustomDialect(csv.excel): delimiter = self.sep self.keys = self.keys temp_file = NamedTemporaryFile(delete=False, mode='w') with open(self.file, 'r') as csv_file, temp_file: reader = csv.DictReader(csv_file, fieldnames=['model'] + [k for k in self.keys if k not in REJECT_KEYS], dialect=CustomDialect) writer = csv.DictWriter(temp_file, fieldnames=['model'] + [k for k in self.keys if k not in REJECT_KEYS], dialect=CustomDialect) for row_idx, row in enumerate(reader): if row_idx == 0: # re-write header with on_train_end's metrics pass if row['model'] == self.row_dict['model']: writer.writerow(row_dict) else: writer.writerow(row) shutil.move(temp_file.name, self.file)