[๋ ผ๋ฌธ์ ๋ฆฌ๐] Pseudo-Labeling and Confirmation Bias in DeepSemi-Supervised Learning
Pseudo-Labeling and Confirmation Bias in Deep Semi-Supervised Learning
- Pseudo label -
๋ ผ๋ฌธ์๋ณธ๐
์ด ๋ ผ๋ฌธ์ semi supervised learning(SSL)์ ์ธก๋ฉด์์ pseudo label์ ํ๋ ์ด์ ์ pseudo label๋ฅผ ์ด๋ป๊ฒ ๊ตฌ์ฑํ์ฌ ์ฑ๋ฅ์ ์ฌ๋ฆด ์ ์์๋์ง์ ๋ํด ์ค๋ช ํ๊ณ ์์ต๋๋ค.
์ ๋ ์ด ๋ ผ๋ฌธ์ ๋ํด์๋ ์์ธํ๋ณด๋ค๋ pseudo label์ด ์ด๋ป๊ฒ ์ด๋ฃจ์ด ์ง๋์ง๋ฅผ loss function๊ณผ ๊ตฌ์กฐ์ ์ธก๋ฉด์์ ํ์ธํด ๋ณด๋ฉฐ ์ด๋ค์์ผ๋ก ์ฑ๋ฅํฅ์์ ๋์์ด ๋๋์ง์ ๋ํด์๋ง ์ ๋ฆฌํ๊ณ ์ ํฉ๋๋ค!โญ๏ธ
๋จผ์ , ๊ฐ๋ตํ๊ฒ pseudo label ์ ํ๋ ์ด์ ์ ๋ํด ์ ์ด๋ณด์๋ฉด,
image classification ๋ถ์ผ์ ์์ด์ ๋ง์ ๋ฐ์ดํฐ์ ์ด ๋ ์ข์ ์ฑ๋ฅ์ ๋ด๋ ๊ฒ์ ๋ณด๊ณ , ๊ฐ๋ฐ์๋ค์ ๋ ๋ง์ ๋ฐ์ดํฐ์ ์ผ๋ก ํ์ต์ํค๊ณ ์ ํ์์ต๋๋ค.
๊ทธ๋์ ๋์จ ๋ฐฉ๋ฒ์ด, label์ด ๋์ง ์์ ๋ ๋ง์ ๋ฐ์ดํฐ์ ์ ๋์ด์์ ์์๋ก labeling์ ํด์ค ํ ํ์ต์ ๋ฐ์ดํฐ๋ฅผ ์ถ๊ฐํ๋ ๋ฐฉ์์ ๋๋ค. ์ด ๋ฐฉ์์ ์ง๋ํ์ต๊ณผ ๋น์ง๋ ํ์ต์ ์์ ํ์์ด๋ฏ๋ก semi-supervised learning(SSL)์ด๋ผ๊ณ ํฉ๋๋ค.
์ฌ๊ธฐ์ ์์๋ก labeling ํ๋ ๊ฒ์, ๋จผ์ ๋ชจ๋ธ์ imagenet ๋ฑ์ ๋ ์ด๋ธ์ด ์๋ ๋ฐ์ดํฐ์
์ผ๋ก ํ๋ฒ ํ์ต์ ์์ผ์ค ํ, ๊ทธ ๋ชจ๋ธ๋ก label์ด ์๋ ๋ฐ์ดํฐ๋ฅผ ์์๋ก labeling์ ์งํํฉ๋๋ค. ์ด๋ ์๊ธด ์์์ label์ pseudo label
์ด๋ผ๊ณ ๋ช
๋ช
ํฉ๋๋ค.
pseudo label๋ก ์ด๋ฃจ์ด์ง ๋ฐ์ดํฐ๋ฅผ ์ถ๊ฐํ์ฌ ํ์ตํ ๊ฒฐ๊ณผ๋ ์ข์ SOTA์ฑ๋ฅ์ ๋ณด์ธ๋ค๋๋ฐ!! ์ฌ๊ธฐ์ ๋ ์๋ฌธ์ ๊ทธ๋ผ pseudo label์์ ์๋ชป๋ label์ด ์์ฑ๋์ด ํ์ต๋ ๊ฒฝ์ฐ
๋ ์ ์ข์ ๊ฒฐ๊ณผ๋ฅผ ๋ผ ํ
๋ฐ.. ์ด๋ค๋ฐฉ์์ผ๋ก labeling ํ์ต์ ํ๋์ง๊ฐ ๊ถ๊ธ
๐ฒ๐ฒํ์ฌ ์ด ๋
ผ๋ฌธ์ ์ฝ๊ฒ ๋์์ต๋๋ค.
์ด์ ์ด๋ป๊ฒ labeling ์ ์ ํ๋๋ก ํ์ต์์ผฐ๋์ง ์์๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.
Pseudo label ์ข ๋ฅ
๐ hard-pseudo-label
๋คํธ์ํฌ์ ์์ธก๊ฐ์ ๋ผ๋ฒจ๋ก ์ฌ์ฉํ๋ ๋ฐฉ์์ผ๋ก one-hot vector์ ์๊ฐํ๋ฉด ๋ฉ๋๋ค. ์๋ฅผ ๋ค์ด ์ด๋ค ์ด๋ฏธ์ง์ ๋ํ ๋ ์ด๋ธ ๊ฐ์ด ๊ณ ์์ด! ์ด๋ ๊ฒ ๋์ค๋ฉด ๊ทธ ์ด๋ฏธ์ง๋ ๊ทธ๋๋ก ๊ณ ์์ด๋ก ๋ ์ด๋ธ ๋ฉ๋๋ค.
๐ soft-pseudo-label
๋ฐ๋ฉด soft-label ๋ฐฉ์์ softmax prediction๊ฐ์ ์ฌ์ฉํฉ๋๋ค. ์ฆ, continuous distribution ํ label ์ ๋ปํ๋ฉฐ ๊ฐ ํด๋์ค๋ก ์์ธก๋ ํ๋ฅ ๊ฐ์ด ๋ค์ด๊ฐ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด ์ด๋ค ์ด๋ฏธ์ง์ ๋ํด์ ๊ณ ์์ด์ผ ํ๋ฅ 70% ํธ๋์ด์ผ ํ๋ฅ 20% ๊ฐ์์ง์ผ ํ๋ฅ 10% ์ด๋ฐ์์ผ๋ก ๋์ค๊ฒ ๋๋ ๊ฒ์ ๋๋ค.
์ด ๋ ผ๋ฌธ์์๋ soft ๋ฐฉ์์ ์ฌ์ฉํ๋ฉฐ ์ด ์ ์ ๋ฆฌ๋ทฐํ๋ noisy student ๋ชจ๋ธ์์๋ softํ ๋ฐฉ์์ ์ฌ์ฉํฉ๋๋ค.
Pseudo label์ ์์คํจ์(loss function)
CNN ํ๋ผ๋ฏธํฐ $\theta$๋ categorical cross-entropy ๊ณต์์ผ๋ก optimize ๋ฉ๋๋ค. ๊ทธ ์์์ ์๋์ ๊ฐ๋ค.
$l^*(\theta) = -\sum_{i=1}^N \tilde{y}^T_ilog(h_\theta(x_i))$
$h_\theta(x)$๋ softmax ํจ์๋ฅผ ๊ฑฐ์ณ ๋์จ ํ๋ฅ ๊ฐ์ ์๋ฏธํ๋ฉฐ ์ฌ๊ธฐ์ $log$๋ฅผ ์ทจํ๋ ๊ฒ์ element-wise ํ๊ธฐ ์ํจ์ ๋๋ค.
๋ค์ ์์ธํ ๋ฏ์ด ๋ณด์๋ฉด
-
unlabel๋ sample : $N_u$
-
unlabel set : $D_u = \lbrace x_i \rbrace^{N_u}_{i=1}$
-
labeled set : $D_l = \lbrace ( x_i,y_i ) \rbrace^{N_l}_{i=1}$
๊ทธ๋ฆฌ๊ณ , one-hot encoding์ ์ํด $y_i$๋ $C$ ํด๋์ค๋ค์ ๋ชจ๋ ๋ฐ์ดํฐ ($N = N_l + N_u$)์ ์ํซ ์ธ์ฝ๋ฉ ํด์ฃผ๋ฉฐ, ๊ทธ ์์์ $y_i={\lbrace 0,1 \rbrace}^C$, ์ด๋ ๊ฒ ํํํ ์ ์์ต๋๋ค.
๋, pseudo label ๋ ๋ฐ์ดํฐ๋ฅผ ๋ค์๊ณผ ๊ฐ์ด ํํํ๋ฉฐ,
$\tilde{D} = \lbrace ( x_i, \tilde{y_i} ) \rbrace^N_{i=1}$
์ด๊ฒ์ ๊ฐ์ง๊ณ ๋ ์ด๋ธ๋ ์ํ๋ค์ธ $N_l$์ ์์ด์ $\tilde{y}=y$ ๊ฐ ๋ ์ ์๋ ๊ฒ์ ๋๋ค.
์ฌ๊ธฐ์ ํต์ฌ์ ์ด๋ป๊ฒ ๋ ์ด๋ธ ๋์ง ์์ ์ํ๋ค ($N_u$) ๋ก๋ถํฐ pseudo-labels ($\tilde{y}$)๋ฅผ ๋ง๋ค์ด ๋ด๋ ๊ฒ์ธ๋ฐ์..!
์ด์ ์ฐ๊ตฌ๋ one-hot encoding์ ์ฌ์ฉํ๋ hard ๋ฐฉ์์ ์ฌ์ฉํ์ต๋๋ค. ํ์ง๋ง softmax๋ฅผ ์ฌ์ฉํ๋ soft ๋ฐฉ์์ด ๋ ์ข์ ์ฑ๋ฅ์ ๋ด๋ ๊ฒ์ ๋ฐ๊ฒฌํ์๊ณ , ์ด์ ๋ฐฉ์์ soft-pseudo labeling ํ๋ ๋ฐฉ์์ ์ ์ฉํ์ฌ ์ฌ์ฉํ์๋ค๊ณ ํฉ๋๋ค. ๋ํ, ๋ ๊ฐ์ง ์ ๊ทํ
๋ฅผ ์ถ๊ฐ์ ์ผ๋ก ์ฌ์ฉํ์ฌ pseudo label์ด ๋ ์ ๋๋๋ก ๋ง๋ค์๋ค๊ณ ํฉ๋๋ค.
๋ ๊ฐ์ง ์ ๊ทํ๋ฅผ ์ ์ฉํ์ฌ ๋์จ ์ต์ข loss ์์์ ์ผ๋จ ๋ค์๊ณผ ๊ฐ์ต๋๋ค. ํ๋ํ๋ ๋ฐ์ ธ๋ณด๊ธฐ ์ ์ ์ ์ฒด์ ์ธ ๊ทธ๋ฆผ์ ๋ณด๊ณ ์ ํจ์ ๋๋ค.
$l=l^*+\lambda_AR_A+\lambda_HR_H$
๊ทธ๋ผ ๋ ์ ๊ทํ๋ฅผ ์ดํด๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค.
๐ ์ฒซ ๋ฒ์งธ ์ ๊ทํ
pseudo label์ ์์ฑํ๊ธฐ ์์ํ๋ ํ์ต ์ด๊ธฐ์๋ ๊ฑฐ์ ๋ถ์ ํํ ๊ฒฐ๊ณผ๋ฅผ ๋ธ๋ค๊ณ ํฉ๋๋ค. ๊ทธ ์ด์ ๋ CNN์ loss๋ฅผ ์ค์ด๊ธฐ ์ํด ๊ฐ์ ํด๋์ค๋ก ์์ธกํด๋ฒ๋ฆฌ๋ ๊ฒฝํฅ์ด ์๊ธฐ ๋๋ฌธ์ ๋๋ค. ์๋ฅผ ๋ค์ด ๋ชจ๋ ๋ฐ์ดํฐ์ ๋ํด์ ์ ๊ฐ๊ฐ์ธ ํด๋์ค๋ก ๋ถ๋ฅํ๊ธฐ ๋ณด๋ค๋ ๊ฐ์ ํด๋์ค๋ก ์ฃผ๋ ๊ฒ์ด loss๊ฐ ๋ ์ ๊ฒ ๋์ค๊ธฐ ๋๋ฌธ์ ๋๋ค. ์ํ์ ๋ณผ ๋ ๋๋คํ๊ฒ ์ฐ๋ ๊ฒ ๋ณด๋จ, ๊ฐ์ ๋ต์ผ๋ก ์ค์ ์ธ์ ์ฐ์ผ๋ฉด ๋ ์๋ง๋ ๊ฒ๊ณผ ๊ฐ์ ์๋ฆฌ๊ฒ ์ง์?
๋ฐ๋ผ์ ์ด ๋ ผ๋ฌธ์์๋ ์๋ ๊ณต์์ ์ถ๊ฐํ์ฌ ๊ฐ ํด๋์ค๋ค์ ๋ชจ๋ ์ํ๋ค์ ์ํฅ๋ ฅ์ ์๊ฒ ํฉ๋๋ค.
$R_A=\sum_{c=1}^Cp_clog({p_c \over \bar{h_c}})$
$p_c$๋ ์ด์ class $c$์ ๋ํ ์ด์ ํ๋ฅ ๋ถํฌ, $\bar{h}_c$๋ class $c$๋ฅผ dataset์ ๋ชจ๋ ์ํ๋ค์ ๋ํ softmax ํ๋ฅ ๊ฐ๋ค์ ํ๊ท ์ ๋๋ค.
๋ฐ๋ผ์ $p_c = {1 \over C}$ ์ ๋๋ค.
๋ฐ๋ผ์ ์ด์ ์ ์์ธกํ ๊ฐ์ ์ ์ฒด ํด๋์ค ๋ถ์ ์์ธกํ ๊ฐ์ ๋ก๊ทธ๋ฅผ ์ทจํ ๊ฐ์ ๊ณฑํด์ ๊ฐ์ ์๊ฒ ์ ๋ฐ์ดํธ๋ฅผ ์์ผ์ฃผ๊ฒ ๋๋ ๊ฒ ๊ฐ์ต๋๋ค.
๐ ๋ ๋ฒ์งธ ์ ๊ทํ
๋ค์ ์ ๊ทํ๋ ์ฝํ ๊ฐ์ด๋์ค(๋ถ์ ํํ ๊ฐ๋ค) ๋๋ฌธ์ local minima์ ๋น ์ง ๊ฒ์ ์ผ๋ คํด ๊ฐ๋ณ class์ ๋ํ soft-pseudo-label์ ๊ฐ ํ๋ฅ ๋ถํฌ์ ์ง์ค
ํ๋๋ก ํ๋ ๋ฐฉ๋ฒ์ ์ถ๊ฐํ๊ฒ ๋ฉ๋๋ค.
์ ๊ทํ ์์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
$R_H=-{1\over N}\sum_{i=1}^N\sum_{c=1}^Ch_\theta^c(x_i)log(h_\theta^c(x_i))$
์ดํด๋ณด๋ฉด, $h_\theta^c(x_i)$๋ softmax์ output์ธ $h_\theta(x)$์ c class value๋ฅผ ์๋ฏธํ๋ฉฐ, ์ด๊ฒ์ผ๋ก entropy๋ฅผ ๊ตฌํ๋ ๊ณต์์ ์ทจํด์ฃผ์ด์ ๊ฐ ์ํ๋ค์ ๋ํ ์ํธ๋กํผ๊ฐ์ ํ๊ท ์ ๊ตฌํ๊ฒ ๋ฉ๋๋ค. ๊ทธ๋ฆฌ๊ณ ์ด๋ ๊ฒ ๋์จ ์ํธ๋กํผ๋ค์ ๋ง์ด๋์ค ๊ฐ์ ๊ฐ์ง๋ฏ๋ก ๋งจ ์์ ๋ง์ด๋์ค๋ฅผ ํ๋ฒ ๋ ์ทจํด์ฃผ์ด ์์๋ก ๋ง๋ค๊ฒ ๋ฉ๋๋ค.
๋ฐ๋ผ์ ์ด ๊ฒ์ผ๋ก ๋์ค๋ ๊ฐ์ ๋ปํ ์์๋๋ ๊ฐ(์์ธก์ด ์ฌ์ด ๊ฐ)์ผ ์๋ก ์์๊ฐ, ๊ฒฐ๊ณผ ์์ธก์ด ํ๋ค์๋ก ํฐ ๊ฐ์ ๋์ถํ๊ฒ ๋ฉ๋๋ค. ๊ทธ๋ฌ๋ฉด ๊ฒฐ๊ณผ ์์ธก์ด ํ๋ค์๋ก ์ ์ฒด์ loss๊ฐ ์ฌ๋ผ๊ฐ๊ฒ ๋๋ ๊ฑฐ๊ฒ ์ฃ !
๋ฐ๋ผ์ ์ด ๋ ์ ๊ทํ๋ฅผ ํฉ์น ์ ์ฒด์ ์ธ loss ์์์ด ์๋์ ๊ฐ์ด ๋๋ ๊ฒ์ ๋๋ค.
$l=l^*+\lambda_AR_A+\lambda_HR_H$
์ฌ๊ธฐ์ $l^*$์, ์ ์ผ hard-pseudo-label๋ฐฉ์์ softmax๋ฅผ ์ถ๊ฐํ ์ด๊ธฐ ๊ณต์์ผ๋ก,
$l^*(\theta) = -\sum_{i=1}^N \tilde{y}^T_ilog(h_\theta(x_i))$
๋ค์ ํ์ด์ ์ฐ๋ฉด, ๋ค์๊ณผ ๊ฐ์ ์์ฃผ ๊ธด ํจ์๊ฐ ๋๋ ๊ฒ์ ๋๋ค.
$l=-\sum_{i=1}^N \tilde{y}^T_ilog(h_\theta(x_i))+\lambda_A \sum_{c=1}^Cp_clog({p_c \over \bar{h_c}})+ -\lambda_H{1\over N}\sum_{i=1}^N\sum_{c=1}^Ch_\theta^c(x_i)log(h_\theta^c(x_i))$
๊ทธ๋ฐ๋ฐ ์ฌ๊ธฐ์ ๋์ด ์๋๋๋ค..ใ
ใ
confirmation bias
์ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด mixup ์ด๋ผ๋ ๊ฐ๋
์ ๋ ์ถ๊ฐํ๊ฒ ๋์๋๋ฐ์,
๐ Confirmation bias (ํ์ฆ ํธํฅ)
mixup
ํ๋ฆฐ pseudo-label๋ก ์ค๋ฒํผํ ๋๋ ๊ฒ์ ํ์ฆํธํฅ ์ด๋ผ๊ณ ํฉ๋๋ค. ๋, pseudo-label์ ํ๋ฉด์ ์๋ชป๋ ๋ ์ด๋ธ๋ก ํ์ต์ ๊ณ์ ํ๊ฒ ๋๋ ๋๋ ๋ง๋ฅผ ๊ทน๋ณตํ๊ธฐ ์ํด mixup์ด๋ผ๋ ๊ฐ๋ ์ ๋์ ํ๊ฒ ๋ฉ๋๋ค.
์ด ๊ฐ๋
์, ์์ ์ ์ธ ๋ชจ๋ธ์ด๋ผ๋ฉด ํน์ ๋ฒกํฐ์ ์ ํ๊ฒฐํฉ์ ๋ํ ์์ธก๊ฐ์ด ๋ ์ด๋ธ์ ์ ํ๊ฒฐํฉ๋ฐฉ์์ด ๋์ด์ผ ํ๋ค
๋ ๊ฐ๋
์์ ๋์ค๊ฒ ๋ฉ๋๋ค. ์์์ผ๋ก ์ค๋ช
๋๋ฆฌ์๋ฉด,
randomํ $(x_p, y_p), (x_q, y_q)$์ ๋ํด์
$x = \delta x_p + (1-\delta)x_q $
$y = \delta y_p + (1-\delta)y_q $
๊ฐ ์ฑ๋ฆฝ๋๋ค๋ ๊ฒ์ ๋๋ค.
ํน์ ์ ๋ ฅ ๋ฒกํฐ๋ค์ ์ ํ๊ฒฐํฉํ ๋ฒกํฐ $x$๋ ๊ทธ ๋ ์ด๋ธ $y$ ๋ํ ๋๊ฐ์ ๋ฐฉ์์ผ๋ก ์ ํ๊ฒฐํฉํ์ ๋, ๊ทธ ๊ฒฐ๊ณผ ๋ํ ๊ฐ์ด ๋งค์นญ๋์ด์ผ ํ๋ค๋ ๊ฒ์ ๋๋ค!
์ด๋ฐ์์ผ๋ก ๋ ์ด๋ธ์ด ์๋ ๋ฐ์ดํฐ๋ค์ ๋ํ mixup ๋ชจ๋ธ์ ์ด์ฉํด ๋ ์ด๋ธ์ ์์ธกํ ํ ์ด๋ฅผ ์ด์ฉํด mixup์ ์งํํ๋ ๋ฐฉ์์ผ๋ก ์ด๋ฃจ์ด ์ง๋๋ค.
์์ ์์์
$l^* = \delta$$l^โ_p$ $+ (1-ฮด)$ $l^โ_p$
๋ก ๋ ์ ์๊ณ ๋ฐ๋ผ์ loss $l^*$์ ๋ํด ์ฌ์ ์ ํ๋ฉด
$l^* = -\sum_{i=1}^N\delta \lbrack\tilde{y}^T_{i,p}log(h_\theta(x_i))\rbrack+(1-\delta)\lbrack\tilde{y}^T_{i,q}log(h_\theta(x_i))\rbrack$ ๊ฐ ๋ฉ๋๋ค.
๋ฐ๋ผ์ ๋ค์์ ์ต์ข ์์
$l=l^*+\lambda_AR_A+\lambda_HR_H$
์์ $l^{*}$๋ง ๋ฐ๋๊ฒ ๋๊ฒ ์ฃ !
์ฌ๊ธฐ๊น์ง ๋ฌ๋ ค์ ๋ณด์๋๋ฐ์, self-training ์ ํตํ noisy student ๊ฐ pseudo label ์ ํตํด ํ์ต์ ํ๋๋ฐ ๋์ฒด ์ด pseudo labeling์ ์ด๋ป๊ฒ ์งํ๋๋ ๊ฒ์ธ์ง.. ์ ๋๋ก ๋ ์ด๋ธ์ด ๋ ๋ฐ์ดํฐ๊ฐ ์ถ๊ฐ๋๋๊ฒ ๋ง๋์ง ์๋ฌธํฌ์ฑ์ด์๋๋ฐ ์ด ๋ ผ๋ฌธ์ ๋ณด๋ ์ดํด๊ฐ ๋์๋ค์๐
์ด์ํ๊ฒ pseudo label ๋ ผ๋ฌธ์ ๋ฆฌ๋ทฐํ ๋ธ๋ก๊ทธ๊ฐ ๊ฑฐ์ ์์ด์ ๋ ผ๋ฌธ์ ํ๊ณ ๋ค์ด ๊ณต๋ถํ๋๋ผ ํ์ด ๋ค์์ง๋ง..
์ฌ๊ธฐ๊น์ง ํ๋ณธ ๋์๊ฒ ๋ฐ์๐ญ๐ญ
๋ค์ ํฌ์คํ ์์ ๋ง๋์๐ฑ๐ฑ
๋๊ธ๋จ๊ธฐ๊ธฐ