Preface
This series of posts records my process of getting started with MXNet. Before MXNet I mainly used Keras, plus a little TensorFlow, so I came to it with some prior deep learning project experience. My main reference is the official MXNet tutorials, supplemented by a few blog posts I found worthwhile.
Each post follows the same structure: first the goals I set for this stage, together with the notes taken while working toward them (only the points I consider important), and then the questions that came up along the way (solved and unsolved, some of them fairly open-ended; discussion is welcome).
Goals for this stage
Task                                        Priority   Estimated time   Status   Problems   Notes
Define the multi-task data format           P0         2 hours
Define the multi-task network               P1         0.5 hours
Define the multi-task evaluation metrics    P2         1.5 hours
Network training and evaluation
Notes
The difference between multi-task and multi-label
Multi-task is more general than multi-label: the network is allowed to branch in its intermediate layers.
Multi-label is a special case of multi-task: when every task is a binary classification the problem reduces to multi-label, whereas in multi-task each task can be a multi-class problem of its own (see the small example below).
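A small illustrative example of the difference in label format (the tag names and class counts here are made up):
import numpy as np

# multi-label: one independent 0/1 entry per tag for a single sample,
# e.g. [has_person, is_outdoor, is_night]
multi_label_target = np.array([1, 1, 0])

# multi-task: one label per task, and each task may have its own number of classes,
# e.g. task 0 is a 10-way category and task 1 is a binary attribute
multi_task_targets = [np.array(7), np.array(1)]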
Defining the multi-task network structure in MXNet
Code and diagram: the network structure branches after flatten0.
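A minimal sketch of what such a branched symbol can look like, assuming a small convolutional backbone and two softmax heads (the layer names, class counts and output names are illustrative, not the ones from the original network):
import mxnet as mx

def build_multi_task_symbol(num_classes_task0=10, num_classes_task1=2):
    data = mx.sym.Variable('data')
    # shared backbone
    conv1 = mx.sym.Convolution(data=data, kernel=(3, 3), num_filter=32, name='conv1')
    relu1 = mx.sym.Activation(data=conv1, act_type='relu', name='relu1')
    pool1 = mx.sym.Pooling(data=relu1, pool_type='max', kernel=(2, 2), stride=(2, 2), name='pool1')
    flatten0 = mx.sym.Flatten(data=pool1, name='flatten0')
    # the graph branches here: one fully connected head plus softmax per task
    fc_task0 = mx.sym.FullyConnected(data=flatten0, num_hidden=num_classes_task0, name='fc_task0')
    out_task0 = mx.sym.SoftmaxOutput(data=fc_task0, name='softmax_task0')
    fc_task1 = mx.sym.FullyConnected(data=flatten0, num_hidden=num_classes_task1, name='fc_task1')
    out_task1 = mx.sym.SoftmaxOutput(data=fc_task1, name='softmax_task1')
    # group the two outputs into a single multi-output symbol
    return mx.sym.Group([out_task0, out_task1])

With this naming, a Module built from the grouped symbol would take label_names=('softmax_task0_label', 'softmax_task1_label').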
Defining the multi-task evaluation metrics
There is plenty of material online about multi-task metrics, but almost all of it stops at multi_accuracy. Collected here are single- and multi-task versions of accuracy / cross-entropy / precision / recall.
The single-task accuracy function:
import mxnet as mx

class Accuracy(mx.metric.EvalMetric):
    def __init__(self, num=None):
        super(Accuracy, self).__init__('accuracy', num)

    def update(self, labels, preds):
        # predicted class of the single output vs. the ground-truth label
        pred_label = mx.nd.argmax_channel(preds[0]).asnumpy().astype('int32')
        label = labels[0].asnumpy().astype('int32')
        mx.metric.check_label_shapes(label, pred_label)
        # accumulate correct predictions and the total number of samples
        self.sum_metric += (pred_label.flat == label.flat).sum()
        self.num_inst += len(pred_label.flat)
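A hypothetical training call using this metric (net, train_iter and val_iter are assumed to be defined elsewhere):
mod = mx.mod.Module(symbol=net, context=mx.cpu())
mod.fit(train_iter,
        eval_data=val_iter,
        eval_metric=Accuracy(),   # the custom metric defined above
        optimizer='sgd',
        optimizer_params={'learning_rate': 0.01},
        num_epoch=10)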
The official multi_accuracy function used for multi-label image classification:
class Multi_Accuracy(mx.metric.EvalMetric):
    """Calculate accuracies of multi label"""

    def __init__(self, num=None):
        self.num = num
        super(Multi_Accuracy, self).__init__('multi-accuracy')

    def reset(self):
        """Resets the internal evaluation result to initial state."""
        self.num_inst = 0 if self.num is None else [0] * self.num
        self.sum_metric = 0.0 if self.num is None else [0.0] * self.num

    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)
        if self.num is not None:
            assert len(labels) == self.num

        for i in range(len(labels)):
            pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
            label = labels[i].asnumpy().astype('int32')
            mx.metric.check_label_shapes(label, pred_label)

            if self.num is None:
                self.sum_metric += (pred_label.flat == label.flat).sum()
                self.num_inst += len(pred_label.flat)
            else:
                self.sum_metric[i] += (pred_label.flat == label.flat).sum()
                self.num_inst[i] += len(pred_label.flat)

    def get(self):
        """Gets the current evaluation result.

        Returns
        -------
        names : list of str
            Name of the metrics.
        values : list of float
            Value of the evaluations.
        """
        if self.num is None:
            return super(Multi_Accuracy, self).get()
        else:
            return zip(*(('%s-task%d' % (self.name, i),
                          float('nan') if self.num_inst[i] == 0
                          else self.sum_metric[i] / self.num_inst[i])
                         for i in range(self.num)))

    def get_name_value(self):
        """Returns zipped name and value pairs.

        Returns
        -------
        list of tuples
            A (name, value) tuple list.
        """
        if self.num is None:
            return super(Multi_Accuracy, self).get_name_value()
        name, value = self.get()
        return list(zip(name, value))
When calling it, set the num argument of Multi_Accuracy (e.g. Multi_Accuracy(num=3)) to specify how many per-task accuracies are computed.
from my_metric import *
eval_metric = mx.metric.CompositeEvalMetric()
eval_metric.add(Multi_Accuracy(num=2))
The cross-entropy / recall / precision metrics below can all be used for either a single task or multiple tasks by adjusting num and name.
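For example, for a two-task setup they could be appended to the composite metric above like this (the name strings are illustrative; the classes themselves are listed below):
eval_metric.add(CrossEntropy(name='cross-entropy', num=2))
eval_metric.add(Recall(name='recall', num=2))
eval_metric.add(Precision(name='precision', num=2))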
cross-entropy
import numpy as np

class CrossEntropy(mx.metric.EvalMetric):
    def __init__(self, eps=1e-12, name='cross-entropy',
                 output_names=None, label_names=None, num=None):
        super(CrossEntropy, self).__init__(
            name, eps=eps,
            output_names=output_names, label_names=label_names)
        self.eps = eps
        self.num = num
        self.name = name
        self.reset()

    def reset(self):
        if getattr(self, 'num', None) is None:
            self.num_inst = 0
            self.sum_metric = 0.0
        else:
            self.num_inst = [0] * self.num
            self.sum_metric = [0.0] * self.num

    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)

        i = 0
        for label, pred in zip(labels, preds):
            label = label.asnumpy()
            pred = pred.asnumpy()
            label = label.ravel()
            assert label.shape[0] == pred.shape[0]

            if i == 1:
                # on task 1 (the "sexy" head) a label of -1 means "no annotation";
                # give these samples probability 1 so they contribute zero loss
                sexy_index = np.where(np.int64(label) == -1)
                label[sexy_index] = 0.0
                pred[sexy_index] = 1.0

            # probability assigned to the true class of each sample
            prob = pred[np.arange(label.shape[0]), np.int64(label)]
            if self.num is None:
                self.sum_metric += (-np.log(prob + self.eps)).sum()
                if i == 1:
                    self.num_inst += (label.shape[0] - len(sexy_index[0]))
                else:
                    self.num_inst += label.shape[0]
            else:
                self.sum_metric[i] += (-np.log(prob + self.eps)).sum()
                if i == 1:
                    self.num_inst[i] += (label.shape[0] - len(sexy_index[0]))
                else:
                    self.num_inst[i] += label.shape[0]
            i += 1

    def get(self):
        if self.num is None:
            if self.num_inst == 0:
                return (self.name, float('nan'))
            else:
                return (self.name, self.sum_metric / self.num_inst)
        else:
            result = [s / n if n != 0 else float('nan')
                      for s, n in zip(self.sum_metric, self.num_inst)]
            return (self.name, result)
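In other words, for each task this metric accumulates the average negative log-likelihood of the true class, CE = -(1/N) * sum_i log(p_i[y_i] + eps), and on task 1 any sample labelled -1 is excluded from both the sum and the count.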
recall
class Recall(mx.metric.EvalMetric):
    def __init__(self, name, num=None):
        super(Recall, self).__init__('Recall')
        self.num = num
        self.name = name
        self.reset()

    def reset(self):
        if getattr(self, 'num', None) is None:
            self.num_inst = 0
            self.sum_metric = 0.0
        else:
            self.num_inst = [0] * self.num
            self.sum_metric = [0.0] * self.num

    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)

        i = 0
        for pred, label in zip(preds, labels):
            pred = mx.nd.argmax_channel(pred).asnumpy().astype('int32')
            label = label.asnumpy().astype('int32')

            # class 0 is treated as the positive class; label -1 means "ignore"
            count_pred = 0   # true positives
            count_truth = 0  # actual positives
            for index in range(len(pred.flat)):
                if label[index] == -1:
                    continue
                if pred[index] == 0 and label[index] == 0:
                    count_pred += 1
                if label[index] == 0:
                    count_truth += 1

            if self.num is None:
                self.sum_metric += count_pred
                self.num_inst += count_truth
            else:
                self.sum_metric[i] += count_pred
                self.num_inst[i] += count_truth
            i += 1

    def get(self):
        if self.num is None:
            if self.num_inst == 0:
                return (self.name, float('nan'))
            else:
                return (self.name, self.sum_metric / self.num_inst)
        else:
            result = [s / n if n != 0 else float('nan')
                      for s, n in zip(self.sum_metric, self.num_inst)]
            return (self.name, result)
precision
class Precision(mx.metric.EvalMetric):
    def __init__(self, name, num=None):
        super(Precision, self).__init__('Precision')
        self.num = num
        self.name = name
        self.reset()

    def reset(self):
        if getattr(self, 'num', None) is None:
            self.num_inst = 0
            self.sum_metric = 0.0
        else:
            self.num_inst = [0] * self.num
            self.sum_metric = [0.0] * self.num

    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)

        i = 0
        for pred, label in zip(preds, labels):
            pred = mx.nd.argmax_channel(pred).asnumpy().astype('int32')
            label = label.asnumpy().astype('int32')

            # class 0 is treated as the positive class; label -1 means "ignore"
            count_truth = 0  # true positives
            count_pred = 0   # predicted positives
            for index in range(len(pred.flat)):
                if label[index] == -1:
                    continue
                if pred[index] == 0 and label[index] == 0:
                    count_truth += 1
                if pred[index] == 0:
                    count_pred += 1

            if self.num is None:
                self.sum_metric += count_truth
                self.num_inst += count_pred
            else:
                self.sum_metric[i] += count_truth
                self.num_inst[i] += count_pred
            i += 1

    def get(self):
        if self.num is None:
            if self.num_inst == 0:
                return (self.name, float('nan'))
            else:
                return (self.name, self.sum_metric / self.num_inst)
        else:
            result = [s / n if n != 0 else float('nan')
                      for s, n in zip(self.sum_metric, self.num_inst)]
            return (self.name, result)
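A quick sanity check of the two metrics on toy data (the arrays below are made up; class 0 is the positive class and -1 marks an ignored sample):
labels = [mx.nd.array([0, 0, 1, -1]), mx.nd.array([0, 1, 1, 0])]
preds = [mx.nd.array([[0.9, 0.1], [0.2, 0.8], [0.3, 0.7], [0.6, 0.4]]),
         mx.nd.array([[0.8, 0.2], [0.7, 0.3], [0.1, 0.9], [0.4, 0.6]])]

recall = Recall(name='recall', num=2)
precision = Precision(name='precision', num=2)
recall.update(labels, preds)
precision.update(labels, preds)
print(recall.get())     # per-task recall of class 0
print(precision.get())  # per-task precision of class 0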