import numpy as np
import torch
import numbers
from . import meter
class ClassErrorMeter(meter.Meter):
    """Accumulate top-k classification error (or accuracy) over batches.

    Class scores and targets are fed batch by batch through :meth:`add`.
    :meth:`value` then reports, in percent, the fraction of samples whose
    target class is *not* among the ``k`` highest-scoring classes. When
    ``accuracy=True``, it instead reports the fraction whose target *is*
    among them.

    Args:
        topk: The ``k`` values to track. Defaults to ``[1]``. A single int
            is also accepted. The values are stored sorted in ``self.topk``.
        accuracy: If True, :meth:`value` returns accuracy instead of error.
    """

    def __init__(self, topk=None, accuracy=False):
        if topk is None:
            topk = [1]
        super(ClassErrorMeter, self).__init__()
        # atleast_1d lets a bare int (e.g. topk=5) work; np.sort on a 0-d
        # array would raise.
        self.topk = np.sort(np.atleast_1d(topk))
        self.accuracy = accuracy
        self.reset()

    def reset(self):
        """Clear all accumulated counts."""
        # One counter of missed samples per *distinct* k value.
        self.sum = {v: 0 for v in self.topk}
        self.n = 0

    def add(self, output, target):
        """Accumulate one batch of predictions.

        Args:
            output: Class scores, shape ``(N, C)``, or ``(C,)`` for a single
                sample. Accepts a torch tensor, a numpy array or a nested
                sequence. Tensors are squeezed first.
            target: Target class indices, shape ``(N,)``. Accepts a torch
                tensor, a numpy array, a sequence or a single number.
                Tensors are squeezed first.

        Raises:
            AssertionError: If ``output`` is not 1-D/2-D, or if ``target``
                does not line up with ``output``.
        """
        if torch.is_tensor(output):
            # detach() so tensors that require grad can be converted to numpy.
            output = output.detach().cpu().squeeze().numpy()
        else:
            output = np.asarray(output)
        if torch.is_tensor(target):
            target = np.atleast_1d(target.detach().cpu().squeeze().numpy())
        else:
            # Covers plain numbers (-> shape (1,)), sequences and numpy arrays.
            target = np.atleast_1d(np.asarray(target))

        if np.ndim(output) == 1:
            # Scores for a single sample: promote to a batch of one.
            output = output[np.newaxis]
        elif np.ndim(output) != 2:
            raise AssertionError('wrong output size (1D or 2D expected)')
        if np.ndim(target) != 1:
            raise AssertionError('target and output do not match')
        if target.shape[0] != output.shape[0]:
            raise AssertionError('target and output do not match')

        maxk = int(self.topk[-1])  # Tensor.topk wants a Python int, not np.int64
        no = output.shape[0]
        # Indices of the maxk highest scores per row, largest first.
        # ascontiguousarray: torch.from_numpy rejects negative-stride arrays.
        pred = torch.from_numpy(np.ascontiguousarray(output)).topk(
            maxk, 1, True, True)[1].numpy()
        # correct[i, j] is True when the j-th ranked prediction of sample i
        # equals its target (broadcast over the maxk columns).
        correct = pred == target[:, np.newaxis]
        # Iterate the distinct k values (the dict keys) rather than
        # self.topk, so a k listed twice at construction is not counted twice.
        for k in self.sum:
            # topk indices within a row are distinct, so each row has at most
            # one hit in its first k columns; the sum counts top-k hits.
            self.sum[k] += no - correct[:, :k].sum()
        self.n += no

    def value(self, k=-1):
        """Return the error (or accuracy) percentage.

        Args:
            k: The top-k value to report. It must be one of the values given
                at construction. The default ``-1`` returns a list with one
                entry per element of ``self.topk``.

        Raises:
            AssertionError: If ``k`` was not provided at construction time.
            ZeroDivisionError: If no samples were added since the last reset.
        """
        if k == -1:
            return [self.value(k_) for k_ in self.topk]
        if k not in self.sum:
            raise AssertionError('invalid k (this k was not provided at construction time)')
        error = float(self.sum[k]) / self.n
        if self.accuracy:
            return (1. - error) * 100.0
        return error * 100.0