SmoothL1 loss

    xiaoxiao  2022-07-05

     

    The implementation is as follows:

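    import torch

    def smooth_l1_loss(input, target, sigma, reduce=True, normalizer=1.0):
        # beta is the switch point between the quadratic and linear branches
        beta = 1. / (sigma ** 2)
        diff = torch.abs(input - target)
        cond = diff < beta
        # quadratic for small errors, linear for large ones
        loss = torch.where(cond, 0.5 * diff ** 2 / beta, diff - 0.5 * beta)
        if reduce:
            # a single scalar: sum over all elements
            return torch.sum(loss) / normalizer
        # one value per row: sum over dim 1 (input needs at least 2 dimensions)
        return torch.sum(loss, dim=1) / normalizer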
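To make the two branches easier to see, here is a minimal sketch of the same element-wise rule in plain Python, for a single pair of floats. The helper name `smooth_l1_scalar` is hypothetical and not part of the original code; it only mirrors the `beta = 1 / sigma ** 2` switch point used above and skips the reduction and normalizer.

```python
def smooth_l1_scalar(x, t, sigma):
    """Hypothetical scalar version of the element-wise rule (illustration only)."""
    beta = 1.0 / (sigma ** 2)          # switch point between the two branches
    diff = abs(x - t)
    if diff < beta:
        return 0.5 * diff ** 2 / beta  # quadratic branch for small errors
    return diff - 0.5 * beta           # linear branch for large errors


# With sigma = 1, beta = 1: an error of 0.5 is quadratic, an error of 3 is linear.
small = smooth_l1_scalar(0.5, 0.0, sigma=1.0)
large = smooth_l1_scalar(3.0, 0.0, sigma=1.0)

# With sigma = 2, beta = 0.25: exactly at the switch point both branches
# give 0.5 * beta, so the loss is continuous there.
at_switch = smooth_l1_scalar(0.25, 0.0, sigma=2.0)
```

A larger `sigma` shrinks `beta`, so the quadratic region narrows and the loss behaves more like plain L1.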

     
