MXNet Study Notes, Part 5: Multi-task Learning in Practice

    xiaoxiao2022-06-30  147

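    Preface

    Goals for this stage

    Detailed notes

    Differences between multi-task and multi-label

    Defining a multi-task network structure in MXNet

    Defining multi-task evaluation metrics

    Single-task accuracy function:

    The official multi_accuracy function for multi-label image classification:

    cross-entropy

    recall

    precision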


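    Preface

    This blog series records my whole process of getting started with MXNet. Before MXNet I mainly used Keras, plus a little TensorFlow, so I came in with some deep learning project experience. My main reference is the official MXNet tutorials; I have also read a number of useful blog posts.

    Each post is structured as follows: first the goals I set for the stage, then the notes taken while working through each goal (only what I consider important), and finally the questions I raised along the way (solved and unsolved, free-ranging; discussion is welcome).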


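    Goals for this stage

    Task                              | Priority | Estimated time | Status | Problems encountered | Notes
    Define the multi-task data format | P0       | 2 hours        |        |                      |
    Define the multi-task network     | P1       | 0.5 hours      |        |                      |
    Define the multi-task metric      | P2       | 1.5 hours      |        |                      |
    Train and evaluate the network    |          |                |        |                      |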

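    Detailed notes

    Differences between multi-task and multi-label

    Multi-task is more complex than multi-label: the network can branch in its intermediate layers.

    Multi-label is a special case of multi-task. When every task is a binary classification, the problem is multi-label; in the general multi-task setting, each task can be a multi-class classification.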
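
    To make the distinction concrete, here is a minimal sketch in plain Python. The helper `is_multi_label` and the class counts are made up for illustration; they are not part of MXNet or of the project described in this post.

```python
# Hypothetical illustration: each entry is the number of classes of one task.

def is_multi_label(classes_per_task):
    """Return True if every task is binary, i.e. the multi-task problem
    reduces to multi-label classification."""
    return all(n == 2 for n in classes_per_task)


# Multi-label: three yes/no tags per image, i.e. three binary tasks.
multi_label_tasks = [2, 2, 2]
multi_label_sample = [1, 0, 1]   # one 0/1 value per tag

# Multi-task: every task may have its own number of classes.
multi_task_tasks = [5, 2, 3]
multi_task_sample = [4, 0, 2]    # one class index per task

print(is_multi_label(multi_label_tasks))  # True
print(is_multi_label(multi_task_tasks))   # False
```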

    Defining a multi-task network structure in MXNet

    Code:

    Diagram: the network structure branches after flatten0.
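
    As a rough idea of what such a network can look like, the sketch below builds a two-branch symbol with the MXNet Symbol API. It is not the original code of this post: the layer sizes, the class counts, the function name `build_multitask_net` and every layer name except `flatten0` are assumptions chosen for illustration.

```python
def build_multitask_net(num_classes_task1=3, num_classes_task2=2):
    """Build a small symbol that splits into two task heads after flatten0.

    Hypothetical example: all hyperparameters here are placeholders.
    """
    # Imported inside the function so the module can be loaded without MXNet.
    import mxnet as mx

    data = mx.sym.Variable('data')

    # Shared trunk used by both tasks.
    conv = mx.sym.Convolution(data=data, kernel=(3, 3), num_filter=16, name='conv0')
    act = mx.sym.Activation(data=conv, act_type='relu', name='relu0')
    pool = mx.sym.Pooling(data=act, pool_type='max', kernel=(2, 2),
                          stride=(2, 2), name='pool0')
    flat = mx.sym.Flatten(data=pool, name='flatten0')

    # Branch 1: its label variable is created implicitly as 'softmax1_label'.
    fc1 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes_task1, name='fc_task1')
    out1 = mx.sym.SoftmaxOutput(data=fc1, name='softmax1')

    # Branch 2: its label variable is created implicitly as 'softmax2_label'.
    fc2 = mx.sym.FullyConnected(data=flat, num_hidden=num_classes_task2, name='fc_task2')
    out2 = mx.sym.SoftmaxOutput(data=fc2, name='softmax2')

    # Group the two heads so the network exposes one output per task.
    return mx.sym.Group([out1, out2])
```

    With a network of this shape, the data iterator has to supply one label array per head (here `softmax1_label` and `softmax2_label`), and a metric receives `labels` and `preds` as lists with one entry per task.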

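    Defining multi-task evaluation metrics

    There is plenty of material online about multi-task metrics, but almost all of it covers only multi_accuracy. Here I have collected single-task and multi-task versions of accuracy / cross-entropy / precision / recall.

    Single-task accuracy function: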

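    import mxnet as mx


    class Accuracy(mx.metric.EvalMetric):
        """Classification accuracy for a single task."""

        def __init__(self, num=None):
            # num is kept only so the signature matches the multi-task metrics.
            # It is not forwarded: in MXNet versions where the second positional
            # argument of EvalMetric is output_names, it would be misread.
            super(Accuracy, self).__init__('accuracy')

        def update(self, labels, preds):
            # Predicted class = index of the highest score for each sample.
            pred_label = mx.nd.argmax_channel(preds[0]).asnumpy().astype('int32')
            label = labels[0].asnumpy().astype('int32')

            mx.metric.check_label_shapes(label, pred_label)

            self.sum_metric += (pred_label.flat == label.flat).sum()
            self.num_inst += len(pred_label.flat)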

    The official multi_accuracy function for multi-label image classification:

    class Multi_Accuracy(mx.metric.EvalMetric):
        """Calculate accuracies of multi label"""

        def __init__(self, num=None):
            self.num = num
            super(Multi_Accuracy, self).__init__('multi-accuracy')

        def reset(self):
            """Resets the internal evaluation result to initial state."""
            self.num_inst = 0 if self.num is None else [0] * self.num
            self.sum_metric = 0.0 if self.num is None else [0.0] * self.num

        def update(self, labels, preds):
            mx.metric.check_label_shapes(labels, preds)

            if self.num is not None:
                assert len(labels) == self.num

            # One (label, prediction) pair per task.
            for i in range(len(labels)):
                pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
                label = labels[i].asnumpy().astype('int32')

                mx.metric.check_label_shapes(label, pred_label)

                if self.num is None:
                    # Pool all tasks into a single accuracy.
                    self.sum_metric += (pred_label.flat == label.flat).sum()
                    self.num_inst += len(pred_label.flat)
                else:
                    # Keep a separate accuracy per task.
                    self.sum_metric[i] += (pred_label.flat == label.flat).sum()
                    self.num_inst[i] += len(pred_label.flat)

        def get(self):
            """Gets the current evaluation result.

            Returns
            -------
            names : list of str
                Name of the metrics.
            values : list of float
                Value of the evaluations.
            """
            if self.num is None:
                return super(Multi_Accuracy, self).get()
            else:
                return zip(*(('%s-task%d' % (self.name, i),
                              float('nan') if self.num_inst[i] == 0
                              else self.sum_metric[i] / self.num_inst[i])
                             for i in range(self.num)))

        def get_name_value(self):
            """Returns zipped name and value pairs.

            Returns
            -------
            list of tuples
                A (name, value) tuple list.
            """
            if self.num is None:
                return super(Multi_Accuracy, self).get_name_value()
            name, value = self.get()
            return list(zip(name, value))

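    When using it, change the num argument, e.g. Multi_Accuracy(num=3), to set how many accuracies are computed.

    import mxnet as mx
    # Assumes the metric classes in this post are saved in my_metric.py.
    from my_metric import *

    eval_metric = mx.metric.CompositeEvalMetric()
    eval_metric.add(Multi_Accuracy(num=2))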
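
    To see what the per-task bookkeeping amounts to, here is a small sketch in plain Python that mirrors the num=2 case without MXNet. The function `per_task_accuracy` and the toy scores are illustrative assumptions, not part of the official example.

```python
def argmax(row):
    """Index of the largest score in a row (first one wins on ties)."""
    return max(range(len(row)), key=lambda k: row[k])


def per_task_accuracy(labels, preds):
    """labels[i] holds the class indices of task i,
    preds[i] holds one row of class scores per sample of task i."""
    results = []
    for task_labels, task_scores in zip(labels, preds):
        predicted = [argmax(row) for row in task_scores]
        correct = sum(p == y for p, y in zip(predicted, task_labels))
        total = len(task_labels)
        results.append(correct / total if total else float('nan'))
    return results


labels = [
    [0, 2, 1, 1],   # task 0: three classes
    [1, 0, 0, 1],   # task 1: two classes
]
preds = [
    [[0.7, 0.2, 0.1], [0.1, 0.2, 0.7], [0.3, 0.4, 0.3], [0.6, 0.3, 0.1]],
    [[0.2, 0.8], [0.9, 0.1], [0.4, 0.6], [0.7, 0.3]],
]

print(per_task_accuracy(labels, preds))  # [0.75, 0.5]
```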

    The cross-entropy / recall / precision metrics below can all be used for either a single task or multiple tasks by changing num and name.

    cross-entropy

    import mxnet as mx
    import numpy as np


    class CrossEntropy(mx.metric.EvalMetric):
        def __init__(self, eps=1e-12, name='cross-entropy',
                     output_names=None, label_names=None, num=None):
            super(CrossEntropy, self).__init__(
                name, eps=eps,
                output_names=output_names, label_names=label_names)
            self.eps = eps
            self.num = num
            self.name = name
            self.reset()

        def reset(self):
            # The base-class constructor calls reset() before self.num is set,
            # hence the getattr guard.
            if getattr(self, 'num', None) is None:
                self.num_inst = 0
                self.sum_metric = 0.0
            else:
                self.num_inst = [0] * self.num
                self.sum_metric = [0.0] * self.num

        def update(self, labels, preds):
            mx.metric.check_label_shapes(labels, preds)

            for i, (label, pred) in enumerate(zip(labels, preds)):
                label = label.asnumpy().ravel()
                pred = pred.asnumpy()
                assert label.shape[0] == pred.shape[0]

                count = label.shape[0]
                if i == 1:
                    # In the second task, label -1 marks samples that must not
                    # contribute to the loss ("no loss for sexy image").
                    sexy_index = np.where(np.int64(label) == -1)
                    label[sexy_index] = 0.0   # any valid class index works here
                    pred[sexy_index] = 1.0    # probability 1 gives (near) zero loss
                    count -= len(sexy_index[0])

                # Probability assigned to the true class of every sample.
                prob = pred[np.arange(label.shape[0]), np.int64(label)]
                loss = (-np.log(prob + self.eps)).sum()

                if self.num is None:
                    self.sum_metric += loss
                    self.num_inst += count
                else:
                    self.sum_metric[i] += loss
                    self.num_inst[i] += count

        def get(self):
            if self.num is None:
                if self.num_inst == 0:
                    return (self.name, float('nan'))
                else:
                    return (self.name, self.sum_metric / self.num_inst)
            else:
                result = [s / n if n != 0 else float('nan')
                          for s, n in zip(self.sum_metric, self.num_inst)]
                return (self.name, result)
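
    The handling of label -1 is the least obvious part of this metric, so here is a small worked sketch in plain Python. The function `masked_cross_entropy` and the numbers are illustrative assumptions. It skips ignored samples outright, which gives the same average as setting their probability to 1 and subtracting them from the sample count, up to the negligible eps term.

```python
import math


def masked_cross_entropy(labels, probs, ignore_label=-1, eps=1e-12):
    """Mean cross-entropy over the samples whose label is not ignore_label."""
    total = 0.0
    count = 0
    for y, p in zip(labels, probs):
        if y == ignore_label:
            continue                     # ignored sample: no loss, not counted
        total += -math.log(p[y] + eps)   # loss of the true class
        count += 1
    return total / count if count else float('nan')


labels = [0, -1, 1]                                  # the middle sample is ignored
probs = [[0.5, 0.5], [0.9, 0.1], [0.25, 0.75]]

# Average of -ln(0.5) and -ln(0.75), roughly 0.4904.
print(masked_cross_entropy(labels, probs))
```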

    recall

    class Recall(mx.metric.EvalMetric):
        """Recall of class 0; samples labelled -1 are skipped."""

        def __init__(self, name, num=None):
            super(Recall, self).__init__('Recall')
            self.num = num
            self.name = name
            self.reset()

        def reset(self):
            # The base-class constructor calls reset() before self.num is set.
            if getattr(self, 'num', None) is None:
                self.num_inst = 0
                self.sum_metric = 0.0
            else:
                self.num_inst = [0] * self.num
                self.sum_metric = [0.0] * self.num

        def update(self, labels, preds):
            mx.metric.check_label_shapes(labels, preds)

            for i, (pred, label) in enumerate(zip(preds, labels)):
                pred = mx.nd.argmax_channel(pred).asnumpy().astype('int32')
                label = label.asnumpy().astype('int32')

                count_pred = 0    # label is 0 and prediction is 0 (true positives)
                count_truth = 0   # label is 0 (all actual positives)
                for index in range(len(pred.flat)):
                    if label[index] == -1:
                        continue
                    if pred[index] == 0 and label[index] == 0:
                        count_pred += 1
                    if label[index] == 0:
                        count_truth += 1

                if self.num is None:
                    self.sum_metric += count_pred
                    self.num_inst += count_truth
                else:
                    self.sum_metric[i] += count_pred
                    self.num_inst[i] += count_truth

        def get(self):
            if self.num is None:
                if self.num_inst == 0:
                    return (self.name, float('nan'))
                else:
                    return (self.name, self.sum_metric / self.num_inst)
            else:
                result = [s / n if n != 0 else float('nan')
                          for s, n in zip(self.sum_metric, self.num_inst)]
                return (self.name, result)

    precision

    class Precision(mx.metric.EvalMetric):
        """Precision of class 0; samples labelled -1 are skipped."""

        def __init__(self, name, num=None):
            super(Precision, self).__init__('Precision')
            self.num = num
            self.name = name
            self.reset()

        def reset(self):
            # The base-class constructor calls reset() before self.num is set.
            if getattr(self, 'num', None) is None:
                self.num_inst = 0
                self.sum_metric = 0.0
            else:
                self.num_inst = [0] * self.num
                self.sum_metric = [0.0] * self.num

        def update(self, labels, preds):
            mx.metric.check_label_shapes(labels, preds)

            for i, (pred, label) in enumerate(zip(preds, labels)):
                pred = mx.nd.argmax_channel(pred).asnumpy().astype('int32')
                label = label.asnumpy().astype('int32')

                count_truth = 0   # prediction is 0 and label is 0 (true positives)
                count_pred = 0    # prediction is 0 (all predicted positives)
                for index in range(len(pred.flat)):
                    if label[index] == -1:
                        continue
                    if pred[index] == 0 and label[index] == 0:
                        count_truth += 1
                    if pred[index] == 0:
                        count_pred += 1

                if self.num is None:
                    self.sum_metric += count_truth
                    self.num_inst += count_pred
                else:
                    self.sum_metric[i] += count_truth
                    self.num_inst[i] += count_pred

        def get(self):
            if self.num is None:
                if self.num_inst == 0:
                    return (self.name, float('nan'))
                else:
                    return (self.name, self.sum_metric / self.num_inst)
            else:
                result = [s / n if n != 0 else float('nan')
                          for s, n in zip(self.sum_metric, self.num_inst)]
                return (self.name, result)
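
    Both metrics treat class 0 as the positive class and skip samples labelled -1. The plain-Python sketch below reproduces that counting on a toy example; the function `precision_recall` and the data are illustrative assumptions, not part of the classes above.

```python
def precision_recall(labels, predicted, positive=0, ignore_label=-1):
    """Precision and recall of the `positive` class, skipping ignored samples."""
    true_pos = 0     # predicted positive and actually positive
    pred_pos = 0     # predicted positive
    actual_pos = 0   # actually positive
    for y, p in zip(labels, predicted):
        if y == ignore_label:
            continue
        if p == positive:
            pred_pos += 1
        if y == positive:
            actual_pos += 1
        if p == positive and y == positive:
            true_pos += 1
    precision = true_pos / pred_pos if pred_pos else float('nan')
    recall = true_pos / actual_pos if actual_pos else float('nan')
    return precision, recall


labels = [0, 0, 1, -1, 0, 1]      # the fourth sample is ignored
predicted = [0, 1, 0, 0, 0, 0]

precision, recall = precision_recall(labels, predicted)
print(precision)  # 2 true positives out of 4 predicted positives
print(recall)     # 2 true positives out of 3 actual positives
```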

     

     

     

     

