"""Source code for pywick.meters.classerrormeter: top-k classification error / accuracy meter."""

import numpy as np
import torch
import numbers
from . import meter


class ClassErrorMeter(meter.Meter):
    """Accumulate top-k classification error (or accuracy) over batches.

    For every ``k`` in ``topk`` the meter counts how many samples did NOT have
    their target class among the ``k`` highest-scoring entries of ``output``.
    :meth:`value` reports that count as a percentage of all samples added, or
    its complement (accuracy) when ``accuracy=True``.

    Args:
        topk: iterable of ``k`` values to track. Defaults to ``[1]``. The values
            are stored sorted (``np.sort``), so the largest ``k`` is last.
        accuracy: if True, :meth:`value` returns top-k accuracy in percent
            instead of top-k error in percent.
    """

    def __init__(self, topk=None, accuracy=False):
        if topk is None:
            topk = [1]
        super(ClassErrorMeter, self).__init__()
        self.topk = np.sort(topk)
        self.accuracy = accuracy
        self.reset()

    def reset(self):
        """Clear all accumulated counts."""
        # Number of samples whose target was not in the top-k, keyed by k.
        self.sum = {v: 0 for v in self.topk}
        # Total number of samples added since the last reset.
        self.n = 0

    def add(self, output, target):
        """Accumulate one batch of predictions.

        Args:
            output: class scores as a torch tensor or array-like, shaped
                ``(N, C)``, or ``(C,)`` for a single sample. Tensors are
                detached, moved to CPU and squeezed before use.
            target: ground-truth class indices as a torch tensor, a 1-D
                array-like of length ``N``, or a single number.

        Raises:
            AssertionError: if ``output`` is not 1-D/2-D, if ``target`` is not
                1-D, or if their first dimensions differ.
        """
        if torch.is_tensor(output):
            # detach(): Tensor.numpy() refuses tensors that require grad,
            # which is typically the case for raw model outputs.
            output = output.detach().cpu().squeeze().numpy()
        else:
            output = np.asarray(output)
        if torch.is_tensor(target):
            target = np.atleast_1d(target.detach().cpu().squeeze().numpy())
        else:
            # Handles a plain number (-> shape (1,)) as well as lists/arrays.
            target = np.atleast_1d(np.asarray(target))

        if np.ndim(output) == 1:
            # A single sample: promote to a batch of one.
            output = output[np.newaxis]
        elif np.ndim(output) != 2:
            raise AssertionError('wrong output size (1D or 2D expected)')
        if np.ndim(target) != 1:
            raise AssertionError('target and output do not match')
        if target.shape[0] != output.shape[0]:
            raise AssertionError('target and output do not match')

        maxk = int(self.topk[-1])  # torch.topk wants a Python int, not np.int64
        num_samples = output.shape[0]
        # Column indices of the maxk largest scores per row, sorted best-first.
        pred = torch.from_numpy(output).topk(maxk, 1, True, True)[1].numpy()
        # Broadcast each row's target across its maxk predictions.
        correct = pred == target[:, np.newaxis]
        for k in self.topk:
            # topk indices within a row are distinct, so each row matches at
            # most once; the sum counts samples whose target is in the top k.
            self.sum[k] += num_samples - correct[:, :k].sum()
        self.n += num_samples

    def value(self, k=-1):
        """Return the top-k error (or accuracy) in percent.

        Args:
            k: one of the ``topk`` values given at construction. With the
                default ``-1``, a list with one value per tracked ``k`` is
                returned, in ascending ``k`` order.

        Raises:
            AssertionError: if ``k`` was not tracked.
            ZeroDivisionError: if called before any sample has been added.
        """
        if k == -1:
            return [self.value(k_) for k_ in self.topk]
        if k not in self.sum:
            raise AssertionError('invalid k (this k was not provided at construction time)')
        error = float(self.sum[k]) / self.n
        if self.accuracy:
            return (1. - error) * 100.0
        return error * 100.0