医学图像分割中的loss函数

本文综述了用于医学图像分割中的 loss 函数,可大致分为基于分布 (Distribution-based) 的损失函数、基于区域 (Region-based) 、基于边界 (Boundary-based)的损失函数、复合损失函数 4 种。

Distribution-based Loss

基于分布的损失函数旨在最小化两个分布之间的差异。这一类别中最基本的函数是交叉熵;其他所有的函数都是从交叉熵导出的。

Cross Entropy

理性认识(信息论)

如果我们对于同一个随机变量 $x$ 有两个单独的概率分布 $P(x)$ 和 $Q(x)$,我们可以使用 KL 散度(Kullback-Leibler (KL) divergence) 来衡量这两个分布的差异:

$$
D_{\mathrm{KL}}(P|Q) = E_{x \sim P} [ \log P(x) - \log Q(x)]
$$

如果对于数据 $x$ 有标签$g$服从真实分布$P(x)$,神经网络预测$s$服从估计分布$Q(x)$,则要使得网络预测越准即使分布 $P(x)$ 和分布 $Q(x)$ 之间的差距越小,即 $D_{\mathrm{KL}}(P|Q)$ 越小。从定义式中删去由数据标签决定的、无法优化的 $H(P)$ ,得到交叉熵函数

$$
H(P, Q) = -E_{x \sim P} \log Q(x)
$$

最小化$H(P, Q)$等同于最小化$D_{\mathrm{KL}}(P|Q)$,规定$\lim_{x→0} x log x = 0$

感性认识

假设我们需要解决 $n$ 个数据 ${x_1, \ldots, x_n}$ 的二分类问题。假设我们分别将$1$和$0$编码为正类标签和负类标签,用$y_i$表示,神经网络的参数为 $\theta$ 。使用最大似然估计法找到最优的 $\theta$ 使得 $\hat{y_i}=p_{\theta}(y_i \mid x_i)$ 。具体来说,对于真实标签 $y_i$ 和预测标签 $\hat{y_i}= p_{\theta}(y_i \mid x_i)$,被分类为正的概率是 $\pi_i= p_{\theta}(y_i = 1 \mid x_i)$。因此有:

$$
\begin{split}\begin{aligned}
l(\theta) &= \log L(\theta) \
&= \log \prod_{i=1}^n {s_i^c}^{g_i^c} (1 - s_i^c)^{1 - g_i^c} \
&= \sum_{i=1}^n g_i^c \log(s_i^c) + (1 - g_i^c) \log (1 - s_i^c). \
\end{aligned}\end{split}
$$

将上述损失函数推广到任何分布,我们也称为交叉熵损失,其中$g$服从真实分布$P$,$s$服从估计分布$Q$

$$
L_{C E}=-\frac{1}{N} \sum_{c=1}^C \sum_{i=1}^N g_i^c \log s_i^c,
$$

所以,交叉熵损失函数是最基础的、起源于 Kullback-Leibler (KL) 散度的衡量分布间不相似度的指标。

Weighted cross entropy (WCE)

为了解决样本不平衡的问题,提出加权交叉熵损失函数。$w_c$是加给每类的权重,可以设为与类出现的频率成反比的值以达到数据分布的平衡。

$$
L_{W C E}=-\frac{1}{N} \sum_{c=1}^c \sum_{i=1}^N w_c g_i^c \log s_i^c
$$

对于二分类问题,值 pos_weight>1 会减少假阴性计数,从而增加召回率。相反,将pos_weight设置为<1 会减少假阳性计数并提高精度。

1
labels * -log(sigmoid(logits)) * pos_weight + (1 - labels) * -log(1 - sigmoid(logits))

TopK loss

TopK loss 旨在迫使网络关注难分样本(hard samples)。通过设定一个 threshold 阈值 t,网络输出与 GT 的差值低于 t 时该样本不计入损失。它的一种表达式是:

$$
L_{T o p K-t h r}=-\frac{1}{\sum_{c=1}^C \sum_{i=1}^N 1 { g_i=c \text { and } s_i^c<t } } \sum_{c=1}^C \sum_{i=1}^N 1 { g_i=c \text { and } s_i^c<t } \log s_i^c
$$

其中,$1{. . . }$ 是二元指示函数。

它的另一种表达式是:

$$
L_{T o p K}=-\frac{1}{N} \sum_{c=1}^C \sum_{i \in \mathbf{K}} g_i^c \log s_i^c
$$

K 是集合中 k% 个最难分的像素点。

Focal loss

Focal loss 通过减少分配给分类良好的样本的损失,来实现对难分样本的关注,可用于处理 foreground-background class imbalance。

$$
L_{\text {Focal }}=-\frac{1}{N} \sum_c^C \sum_{i=1}^N\left(1-s_i^c\right)^\gamma g_i^c \log s_i^c
$$

$s_i$ 反映了与 ground truth 即类别 y 的接近程度, $s_i$ 越大说明越接近类别 y,即分类越准确。对于分类准确的样本  $p_t→1$ ,modulating factor 趋近于 0。对于分类不准确的样本  $1−p_t→1$ ,modulating factor 趋近于 1。即相比交叉熵损失,focal loss 对于分类不准确的样本,损失没有改变,对于分类准确的样本,损失会变小。 整体而言,相当于增加了分类不准确样本在损失函数中的权重。

  • 可调节因子 $\gamma >0$ 。
  • modulating factor $(1−s_i^c)^γ$
  • 当 $\gamma = 0$ 时,focal loss 就是交叉熵损失函数。
  • 原始论文中 $\gamma = 2$ 效果最好。

Distance map penalized cross entropy loss (DPCE)

$$
L_{D P C E}=-\frac{1}{N} \sum_{c=1}^c\left(1+D^c\right) \circ \sum_{i=1}^N g_i^c \log s_i^c,
$$

$D^c$为对 GT 做欧几里得距离变换(scipy.ndimage.morphology.distance_transform_edt

欧氏距离变换能将二值图转换为灰度图,是指对于一张二值图像(假定白色为前景色,黑色为背景色),将前景中的像素的值转化为该点到达最近的背景点的距离。多应用于图像的骨架提取。

Distribution-based 总结

基于分布 (Distribution-based) 的损失函数通过在

Region-based Loss

基于区域的损失函数旨在最小化 G 和 S 之间的 mismatch 或最大化重叠区域(overlap),如 Dice 损失函数。

Dice Loss

先看三个医学上常见的评价指标。
一是精确度 Precision,是所有预测为真的样本中,真实值也为真类的概率(查准率):

$$
\text{ Precision }=\frac{TP}{TP+FP}
$$

二是召回率 Recall 或者叫敏感性 Sensitivity,代表所有真实值也为真样本中,预测为正的概率,即查全率。

$$
\text{ Recall } = \text{ Sensitivity }=\frac{TP}{TP+FN}
$$

三是特异性 specificity,

$$
\text{ Specificity }=\frac{TN}{TN+FP}
$$

那 Accuracy 表征的是预测正确的样本比例。不过通常不用这个概念,主要是因为预测正确的负样本这个没有太大意义。

样本总数

举一个癌症的例子,
precision 高是说我只有有很大的把握得时候我才会说你得癌症,意思是,只要我说你得癌症,你基本上就是得了癌症。
recall 高是说我基本能找到所有得癌症的人,意思是,得癌症的人一定在我说得癌症的人当中。
这两个指标显然不能一高一低,需要同时兼顾。为了综合考虑这两个指标,我们有 Dice 系数,是 precision 和 recall 的调和平均值:

$$
\text { Dice similarity coefficient}=\text{ F1-score } = \frac{1}{ \frac{1}{\text{ precision }}+\frac{1}{\text{ recall }}} = \frac{2 T P}{2 T P+F P+F N}
$$

Dice 系数在 0 到 1 之间,故 Dice loss $L_{\text { Dice }}$ 为:

$$
L_{\text { Dice }}= 1-\text{ Dice } \\ = 1-\frac{2 \sum_{c=1}^C \sum_{i=1}^N g_i^c s_i^c}{\sum_{c=1}^C \sum_{i=1}^N g_i^c+\sum_{c=1}^C \sum_{i=1}^N s_i^c}
$$

Sensitivity-specificity loss

Sensitivity-specificity loss 强调特异性:

$$
\begin{aligned}
L_{S S}=& w*sensitivity+(1-w)*specificity \ =& w \frac{\sum_{c=1}^c \sum_{i=1}^N\left(g_i^c-s_i^c\right)^2 g_i^c}{\sum_{c=1}^C \sum_{i=1}^N g_i^c+\epsilon} \
&+(1-w) \frac{\sum_{c=1}^C \sum_{i=1}^N\left(g_i^c-s_i^c\right)^2\left(1-g_i^c\right)}{\sum_{c=1}^C \sum_{i=1}^N\left(1-g_i^c\right)+\epsilon}
\end{aligned}
$$

IoU (Jaccard) loss

$$
L_{I O U}=1-IoU=1-\frac{\sum_{c=1}^C \sum_{i=1}^N g_i^c s_i^c}{\sum_{c=1}^C \sum_{i=1}^N\left(g_i^c+s_i^c-g_i^c s_i^c\right)}
$$

$$
IoU=\frac{TP}{FP+TP+FN}
$$

Lovász loss

使用 Lovasz extension 将离散的 Jaccard loss 变成光滑的形式,从而可以直接进行求导(不稳定,作者推荐先使用 ce loss 再使用这个 finetune)

$$
m_i(c)= \begin{cases}1-s_i^c, & \text { if } c=g_i \ s_i^c, & \text { otherwise }\end{cases}
$$

$$
L_{\text {lovasz }}=\overline{\Delta J_c}(m(c))
$$

Tversky loss

对 Dice loss 的每一项都加上一个系数,得到 Tversky loss:

$$
\text { Tversky }=\frac{T P}{T P+\alpha * F P+\beta * F N}
$$

$$
\begin{aligned}
L_{\text {Tversky }}=& 1-T(\alpha, \beta) \
=& 1-\left(\sum_c^C \sum_{i=1}^N g_i^c s_i^c\right) /\left(\sum_c^C \sum_{i=1}^N g_i^c s_{\dot{i}}^c\right.\
&\left.+\alpha \sum_c^C \sum_{i=1}^N\left(1-g_i^c\right) s_i^c+\beta \sum_c^C \sum_{i=1}^N g_i^c\left(1-s_i^c\right)\right)
\end{aligned}
$$

Tversky loss 能更好地权衡精确性和召回率(FPs 与 FNs)。通过调整超参数 α 和 β,我们可以控制假阳性(False positives)和假阴性(False negatives)之间的权衡。值得注意的是,在 α=β=0.5 时 Tversky 指数为与骰子系数相同,也等价于 F1-score。当 α=β=1 时,等式 2 产生 Tanimoto 系数,设置 α+β=1 产生 Fβ scores。βs 越大,召回率越高(通过更加强调假阴性)。我们假设在我们的广义损失函数中使用更高的 βs 将导致对不平衡数据的更高的泛化和更好的性能,并有效地帮助我们将重点转移到降低 FNs 和提高召回率。

Asymmetric similarity loss

引入加权参数 $\beta$ 来更好地调整 FP 和 FNs 的权重

$$
\begin{aligned}
L_{A s y m}=&\left(\sum_{c=1}^C \sum_{i=1}^N g_i^c s_i^c\right) /\left(\sum_{c=1}^C \sum_{i=1}^N g_i^c s_i^c\right.\
&\left.+\frac{\beta^2}{1+\beta^2} \sum_{c=1}^C \sum_{i=1}^N g_i^c\left(1-s_i^c\right)+\frac{1}{1+\beta^2} \sum_{c=1}^C \sum_{i=1}^N\left(1-g_i^c\right) s_i^c\right)
\end{aligned}
$$

$$
\begin{equation*} F(P,G;\beta) = \frac {|PG|}{|PG|+\frac {\beta ^{2}}{(1+\beta ^{2})} |G\setminus P|+\frac {1}{(1+\beta ^{2})}|P\setminus G|}\quad \end{equation*}
$$

原始论文中推荐的 β 为 1.5;当 α+β=1 时,不对称相似性损失也是 Tversky 损失的一个特例

Focal Tversky loss

借助系数$γ$关注低概率的 hard cases

$$
L_{F T L}=\left(L_{\text {Tversky }}\right)^{\frac{1}{y}},
$$

$\gamma \in [1,3]$.

Generalized Dice loss

Generalized Dice loss 是 Dice loss 的多类别扩展。当病灶分割有多个类别时,一般针对每一类都会有一个 Dice,而 Generalized Dice index 将多个类别的 Dice 进行整合,使用一个指标对分割结果进行量化。它可以表示为:

$$
L_{G D}=1-2 \frac{\sum_{c=1}^C w_c \sum_{i=1}^N g_i^c s_i^c}{\sum_{c=1}^C w_c \sum_{i=1}^N\left(g_i^c+s_i^c\right)}
$$

Penalty loss

就是在 Generalized Dice loss 的基础上额外给 FN 和 FP 加上了系数$k$。

$L_{p G D}=1-p G D$ where the $p G D$ is defined by

$$
\begin{aligned}
&p G D=2\left(\sum_{c=1}^c w_c \sum_{i=1}^N g_i^c S_i^c\right) /\left(\sum_{c=1}^c w_c \sum_{i=1}^N\left(g_i^c+s_i^c\right)\right. \
&\left.+k \sum_{c=1}^c w_c \sum_{i=1}^N\left(1-g_i^c\right) s_i^c+k \sum_{c=1}^c w_c \sum_{i=1}^N g_i^c\left(1-s_i^c\right)\right) \
&
\end{aligned}
$$

当 $K=0$ 时 $pGD=GD$;当 $K \gt 0$ 时 $pGD$ 给假阳性 FP 和假阴性 FN 额外的权重。

证明:

$$
FN = \sum_n^p\left(1-G_{l n}\right) P_{l n}
$$

$$
FP = \sum_n^p G_{l n}\left(1-P_{l n}\right)
$$

$$
\begin{aligned}
p G D&=2 \frac{\sum_{l=1}^c w_l \sum_n^p G_{l n} P_{l n}}{\sum_{l=1}^c w_l \sum_n^p\left(G_{l n}+P_{l n}\right)+k \sum_{l=1}^c w_l \sum_n^p\left(1-G_{l n}\right) P_{l n}+k \sum_{l=1}^c w_l \sum_n^p G_{l n}\left(1-P_{l n}\right)} \
&=2 \frac{\sum_{l=1}^c w_l \sum_n^p G_{l n} P_{l n}}{\sum_{l=1}^c w_l \sum_n^p\left(G_{l n}+P_{l n}\right)+k \sum_{l=1}^c w_l \sum_n^p\left(\left(1-G_{l n}\right) P_{l n}+G_{l n}\left(1-P_{l n}\right)\right)} \
&=2 \frac{\sum_{l=1}^c w_l \sum_n^p G_{l n} P_{l n}}{\sum_{l=1}^c w_l \sum_n^p\left(G_{l n}+P_{l n}\right)+k \sum_{l=1}^c w_l \sum_n^p\left(P_{l n}-2 P_{l n} G_{l n}+G_{l n}\right)} \
&=\frac{G D}{1+k(1-G D)} \
&
\end{aligned}
$$

Region-based 总结

基于区域的损失函数以 TP、TN、FP、FN 为单位出发,Dice

Boundary-based Loss

最小化 GT 和预测之间的距离

分割里常用的 Cross-entropy 和 Dice loss 对 class-imbalance 比较敏感,训练出来的网络会 bias 到背景。
将 cross-entropy 的积分区域拆开为前景和背景,可以发现在反向传播的时候,背景的 loss 要占主要部分。
regional loss 的梯度对背景像素一视同仁,忽略了空间的信息,比如离 ground truth 距离远的误分割应该给较大的惩罚权重。

Boundary (BD) loss

如果 prediction 和 label 一致,loss 为 0。如果 prediction 比 label 小并被 label 包围,loss 为负。

$$L_{B D}=\sum_{\Omega} \phi_G(p) s_\theta(p)$$

如果 $q \in G $,则$\phiG(q) = −D_G(q)$,否则$\phi_G(q) = D_G(q)$。
$S
\theta $代表网络的 softmax 输出。

区域表示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from scipy.ndimage import distance_transform_edt as distance

# 二值化
posmask = img_gt[b][c].astype(np.bool)
negmask = ~posmask

# 求出 distance map
posdis = distance_transform_edt(posmask)
negdis = distance_transform_edt(negmask)

# 找出边界
boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)

# 两个减一下就是 phi 了
sdf = negdis - posdis

# 边界的distance为0
sdf[boundary==1] = 0
phi[b][c] = sdf

# Hadamard 积
multipled = torch.einsum("bcxyz,bcxyz->bcxyz", net_output[:, 1:, ...], phi[:, 1:, ...])

bd_loss = multipled.mean()

如果将边界损失看成纯粹是基于距离的惩罚,目标内部的梯度将为负。 在梯度下降过程中提高目标类别的概率。 相反,背景像素具有正梯度,从而在 SGD 期间将预测的概率降低。如果距离图全为正,它将降低所有像素(前景或背景)的预测概率。

Hausdorff Distance (HD) loss

豪斯多夫距离量度度量空间中真子集之间的距离:

$$hd(X,Y)=max_{x \epsilon X}min_{y \epsilon Y} ||x-y||_2$$

$$hd(Y,X)=max_{y \epsilon Y}min_{x \epsilon X} ||x-y||_2$$

$$HD(X,Y)=max(hd(X,Y), hd(Y,X))$$

优化下式即优化 distance:

$$
L_{H D}=\frac{1}{N} \sum_{c=1}^C \sum_{i=1}^N\left[\left(s_i^c-g_i^c\right)^2 \circ\left(d_{G_i^c}^2+d_{S_i^c}^2\right)\right]
$$

Compound Loss

Combo loss

$$L_{DiceCE} = \alpha L_{CE} + (1-\alpha)L_{Dice}$$

组合损失[15]定义为 Dice 损失和修改的交叉熵的加权和。它试图利用 Dice 损失类不平衡的灵活性,同时使用交叉熵进行曲线平滑。

Exponential Logarithmic loss (ELL)

提出对 Dice 损失和交叉熵损失进行指数和对数操作从而将更多的注意力集中在预测精度较低的结构上,可用来预测不确定的结构。

$$
\begin{aligned}
L_{E L L}=w_{\text {Dice }} E\left[\left(-\log \left(\text { Dice }c\right)\right)^{\gamma \text { Dice }}\right]
&+w
{C E} E\left[w_c\left(-\log \left(s_i^c\right)\right)^{\gamma C E}\right] \
\text { Where } \text { Dice }c=\frac{2 \sum{i=1}^N g_i^c s_i^c+\epsilon}{\sum_{i=1}^N\left(g_i^c+s_i^c\right)+\epsilon} \text {. }
\end{aligned}
$$

Dice loss with focal loss

$$L_{DiceFocal} = L_{Dice} + L_{Focal}$$

Dice loss with TopK loss

$$L_{DiceTopK} = L_{Dice} + L_{TopK}$$

Tailored loss

  • 针对特殊任务设计,用于通过探索标签关系改进多类分割
  • 形状先验和约束也被合并到损失函数中,以利用期望的拓扑结构来实现分割结果。
  • 为了提高网络的区域标记一致性,(Ganaye 等人,2019)提出了一种基于邻接图的辅助训练损失,该损失可以惩罚包含具有解剖学上不正确邻接关系的区域的输出。
  • 设计了一个连续值损失函数,该函数可以强制分割,使其具有与 GT 相同的 Betti 数。

附:常用指标的理解

  • TP:True Positive,分类器预测结果为正样本,实际也为正样本,即正样本被正确识别的数量。
  • FP:False Positive,分类器预测结果为正样本,实际为负样本,即误报的负样本数量。
  • TN:True Negative,分类器预测结果为负样本,实际为负样本,即负样本被正确识别的数量。
  • FN:False Negative,分类器预测结果为负样本,实际为正样本,即漏报的正样本数量。

组合起来就是:

  • TP+FN:真实正样本的总和,正确分类的正样本数量+漏报的正样本数量。
  • FP+TN:真实负样本的总和,负样本被误识别为正样本数量+正确分类的负样本数量。
  • TP+TN:正确分类的样本总和,正确分类的正样本数量+正确分类的负样本数量。

Accuracy:准确率

Accuracy 表征的是预测正确的样本比例。不过通常不用这个概念,主要是因为预测正确的负样本这个没有太大意义。

样本总数

Precision:查准率

Precision 表征的是预测正确的正样本的准确度,查准率等于预测正确的正样本数量/所有预测为正样本数量。Precision 越大说明误检的越少,Precision 越小说明误检的越多。

Recall:查全率
Recall 表征的是预测正确的正样本的覆盖率,查全率等于预测正确的正样本数量/所有正样本的总和,TP+TN 实际就是 Ground Truth 的数量。Recall 越大说明漏检的越少,Recall 越小说明漏检的越多。

References

  1. MA J, CHEN J, NG M, et al. Loss odyssey in medical image segmentation[J/OL]. Medical Image Analysis, 2021, 71: 102035. DOI:10.1016/j.media.2021.102035.
  2. 医学影像分割中常用的损失函数 - 知乎