[๋ ผ๋ฌธ์ ๋ฆฌ๐] Self-training with Noisy Student improves ImageNet classification
Self-training with Noisy Student improves ImageNet classification
- Noisy Student -
๋ ผ๋ฌธ์๋ณธ๐
์ด๋ฒ ๋ ผ๋ฌธ์ EfficientNet ์์ ๋ ๋ง์ ๋ฐ์ดํฐ์ ์ผ๋ก ํ์ต์์ผ ์๋ก์ด SOTA ์ฑ๋ฅ์ ๋ณด์ด๋ Noisy student์ ๋ํด ๋ฆฌ๋ทฐํ๋๋ก ํ๊ฒ ์ต๋๋ค.
๐ Introduction
๊ทธ ๋์ ๋ฅ๋ฌ๋์ ์ด๋ฏธ์ง ์ธ์์ ์์ด์ ๋๋ถ์ ์ฑ๊ณต์ ๋ณด์์ต๋๋ค. ํ์ง๋ง SOTA(state-of-the-art) ์ฑ๋ฅ์ ๋ด๊ธฐ ์ํด์ ์์ฃผ ํฐ ๋ฐ์ดํฐ์ ์ด ํ์ํ๋ค๋ ํ๊ณ์ ๋ถ๋ชํ์ต๋๋ค. ์ด ๋ฌธ์ ๋ฅผ ๋ณด์ํ๊ณ ์ ๋์จ๊ฒ์ด ๋ ์ด๋ธ ๋์ง ์์ ๋ฐ์ดํฐ์ (unlabeled-dataset)์ ์ถ๊ฐํด semi-supervised learning ์ ์ฌ์ฉํ๋ ๋ฐฉ์์ ๋๋ค. ๋ฒ ์ด์ค ๋ชจ๋ธ์ EfficientNet ๋ชจ๋ธ์ ์ฌ์ฉํ์์ผ๋ฉฐ, ๊ฐ Efficient-B0~B7 ๋ชจ๋ธ๋ก ํ์ตํ ๊ฒฐ๊ณผ ์ฑ๋ฅ์ด ๋ ํฅ์๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค.
๋ ์ด๋ธ๋์ง ์์ ๋ฐ์ดํฐ์ ์ JFT-300M ๋ฐ์ดํฐ ์ ์ ์ฌ์ฉํ์์ผ๋ฉฐ ๊ตฌ๊ธ ๋ด์์ ๋ง๋ ๋ฐ์ดํฐ์ ์ด๋ผ๊ณ ํฉ๋๋ค. ์ฝ 3์ต๊ฐ์ ์ด๋ฏธ์ง๊ฐ ์์ด ์ฝ 125๋ง๊ฐ ์ ๋๋๋ imageNet ๋ฐ์ดํฐ์ ์ ๋นํ๋ฉด ์์ฒญ๋ ํฌ๊ธฐ์ ๋๋ค. JFT-300M ๋ฐ์ดํฐ์ ์ ๋ํ ์ค๋ช ์ JFT-300M ๋ฐ์ดํฐ์ ์ด๋?ํฌ์คํ ํ ๊ธ์ ์ฐธ๊ณ ํ์ธ์!๐
ํ์ต ๋ฐฉ์์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
๋จผ์ , (1) ์ ์๋ ๋ชจ๋ธ์ ๋ ์ด๋ธ๋ ๋ฐ์ดํฐ์ (imageNet dataset)์ผ๋ก ํ์ต์ํจ ํ (2) ํ์ต๋ ์ ์๋ ๋ชจ๋ธ๋ก ๋ ์ด๋ธ๋์ง ์์ ๋ฐ์ดํฐ์ (JFT-300M)์ ์๋ ๋ ์ด๋ธ(pseudo-label)์ ์์ฑํด๋ ๋๋ค. ๊ทธ ํ (3) ํ์ ๋ชจ๋ธ์ด (1)๊ณผ (2)๊ฐ ํฉ์ณ์ง ๋ฐ์ดํฐ์ ์ผ๋ก ํ์ต์ ํ๊ฒ ๋ฉ๋๋ค. ๊ทธ๋ ๊ฒ ๋๋ฉด ์ ์๋ ๋ชจ๋ธ๊ณผ ๊ฐ๊ฑฐ๋ ๋ ํฐ ๋ชจ๋ธ์ด ์์ฑ๋ฉ๋๋ค. ์ดํ์ (4) step(2)์ (3)์ ๋ฐ๋ณตํ๋ฉด์ ๋ชจ๋ธ ์ฑ๋ฅ์ ๋ ๋์ด๋๋ก ํ์ตํฉ๋๋ค.
๋ ผ๋ฌธ์์ ์ ์ํ ์ด๋ฏธ์ง๋ ์๋์ ๊ฐ์ต๋๋ค.
์ ๋ฆฌํ์๋ฉด,
Noisy student Training process
๋จ, ์ฌ๊ธฐ์ ์๋ฌธ์ธ ๊ฒ์ pseudo label์ด ์ด๊ธฐ์๋ ๋ถ๋ช ํํ ํ ๋ฐ ์ด๋ป๊ฒ ํด์ ๋ฐ๋ณตํ๋ฉด์ ํ์ตํ๋ฉด ์ฑ๋ฅ์ด ์ข์์ง๋์ง ๊ทธ ๋ถ๋ถ์ ๋ํ ์ค๋ช ์ด ์ข ๋ถ์กฑํ๋ค. soft labeled ๋ฅผ ์ฌ์ฉํ์๋ค๊ณ ์งง๊ฒ ํ์ค ์ค๋ช ์ด ์๋๋ฐ, ์ด ๋ถ๋ถ์ ๋ํด์๋ ์ด ๋ ผ๋ฌธ์ ์ฒซ๋ฒ์งธ ๋ ํผ๋ฐ์ค์ธ Pseudo-Labeling and Confirmation Bias in Deep Semi-Supervised Learning์ ์ฐธ๊ณ ํ๋ฉด ๋์์ด ๋๋ค.
๐ ์ด pseudo label ๋ ผ๋ฌธ์ ๋ํด์ ์ ๋ฆฌํ ๊ธ์ [๋ ผ๋ฌธ ์ ๋ฆฌ]Pseudo-Labeling and Confirmation Bias in Deep Semi-Supervised Learning์ ์ ๋ฆฌํด ๋์์ผ๋ ํ์ํ๋ค๋ฉด ์ฐธ๊ณ !โ๏ธ
์ด๋ ๊ฒ ๋ ์ด๋ธ ๋์ง์์ JFT-300M ๋ฐ์ดํฐ์
์ ์ถ๊ฐํด imageNet ์ ๋ํ SOTA ์ฑ๋ฅ์ ํฅ์์์ผฐ์ผ๋ฉฐ ๋์์ ๊ฒฌ๊ณ ํ ๋ชจ๋ธ
์ด ์์ฑ๋์์ต๋๋ค. ์ฌ๊ธฐ์ ๊ฒฌ๊ณ ํ๋ค๋ ๊ฒ์ ๋
ผ๋ฌธ์์ โrobustnessโ๋ผ๊ณ ํํํ์๋๋ฐ,
robustnessํ ๋ชจ๋ธ์ ๋ง๋ค์๋ค๋ ๊ฒ์ ๋ชจ๋ธ์ ๋ค์ด์ค๋ ๋ฐ์ดํฐ๊ฐ ์ฒ์๋ณด๋ ๋ฐ์ดํฐ์ด๊ฑฐ๋ ์๋ ํ์ตํ์๋ ๋ฒ์ฃผ์ ์ข ๋๋จ์ด์ง ๋ฒ์ฃผ์ ์ด๋ฏธ์ง๋ ์ ๋ถ๋ฅํ ์ ์๋ค๋ ๋ป์
๋๋ค.
์๋๋ noisy student๋ฐฉ์์ผ๋ก ํ์ตํ ๋ชจ๋ธ์ด ImageNet dataset๋ค์ SOTA ์ฑ๋ฅ์ ๋ณด์ด๋ ๊ฒ์ ๋ํ๋ด๋ ์งํ์ ๋๋ค.
ํ์ ๋งจ ์ผ์ชฝ์ ImageNet ๋ฐ์ดํฐ์ ์ด๋ฉฐ ์ฐจ๋ก๋๋ก ๋ฐ์ดํฐ์ ๊ณผ ์ฑ๋ฅ์งํ์ ๋ํ ์ค๋ช ์ ํ์๋ฉด,
ImageNet-A : ๊ตฌ๋ถํ๊ธฐ ์ด๋ ค์ด 200 classes์ ์ด๋ฏธ์ง๋ค๋ก ๊ตฌ์ฑ๋ dataset
ImageNet-C : 15๊ฐ์ง corruption์ผ๋ก ์คํ๋ ๋ฐ์ดํฐ์ ์ด๋ฉฐ ๋ณํ๊ฐ ์์ฃผํฐ ๋ฐ์ดํฐ์ ์ด๋ค. mCE(mean corruption error)๋ฅผ ์ฑ๋ฅ์งํ๋ก ์ฌ์ฉํ๊ณ ์๋ค.
ImageNet-P : ImageNet-C ์ ๋นํ๋ฉด ์ฝ๊ฐ์ ๋ ธ์ด์ฆ๋ ๊ธฐ์ธ์ด์ง ์ ๋๋ก augmentation๋ ๋ฐ์ดํฐ์ ์ด๋ค. ์ ์ ๋ณํ์ ๋ฐ์ดํฐ์ ์ด๋ผ๊ณ ์๊ฐํ๋ฉด ๋ ๊ฒ ๊ฐ๋ค. mFR(mean flip rate)์ ์ฑ๋ฅ์งํ๋ก ์ฌ์ฉํ๊ณ ์๋ค.
mCE(mean corruption error) : image noise์ ์ผ๋ง๋ ๊ฐํ ๋ชจ๋ธ์ธ์ง๋ฅผ ๋ณด๋ ์งํ๋ก noise๊ฐ ์๋ ๋ฐ์ดํฐ๋ ์ ์์ธกํด์ผํ๋ฏ๋ก ์์น๊ฐ ์์ ์๋ก ์ข์ ์ฑ๋ฅ์ ๋ํ๋ธ๋ค.
mFR(mean flip rate) : perturbation(์์ ๋ณํ)์ด ๋ฐ๋ ๋ top-1 prediction์ด ๋ฐ๋ ํ๋ฅ ์ ๋ํ๋ด๋ฉฐ ๋ณํ๊ฐ ์ผ์ด๋ ๋ฐ์ดํฐ์ ์์ธก์ ๋ณํํ ํ๋ฅ ์ด ์ ์ด์ผ ์ข์ ์ฑ๋ฅ์ด๋ฏ๋ก ์์น๊ฐ ์ ์ ์๋ก ์ข์ ์ฑ๋ฅ์ ๋ํ๋ธ๋ค.
๐ iterative training ์ ํตํด ํ์ตํ ์ต๊ณ ์ฑ๋ฅ ๋ชจ๋ธ
iterative training ์ ํตํด ํ์ตํ ์ต๊ณ ์ฑ๋ฅ ๋ชจ๋ธ์ ๋ค์๊ณผ ๊ฐ๋ค. ์ฒ์ teacher๊ณผ student ๋ชจ๋ธ์ EfficientNet-B7์ผ๋ก ํ์ต์ํจ ํ ๊ทธ ์ดํ๋ก๋ EfficientNet-L2๋ก 3๋ฒ์ ๋ฐ๋ณต์ ํตํด ์ป์ ๋ชจ๋ธ์ด SOTA์ฑ๋ฅ์ ๋ณด์ธ๋ค๊ณ ํ๋ค. ์ฌ๊ธฐ์ EfficientNet-L2 ๋ efficientNet์ ํ๋ผ๋ฏธํฐ๊ฐ์ scaling์ ํฌ๊ฒ ํ์ฌ ์์ฑํ ๋ชจ๋ธ์ด๋ค. ์์ธํ scaling ๋ฐฉ๋ฒ์ ๋ํด์๋ EfficientNet ๋ ผ๋ฌธ๋ฆฌ๋ทฐ์์ ํ์ธํ ์ ์๋ค.
๋ณด๋ฉด label ๋ฐ์ดํฐ์ ๊ณผ unlabel ๋ฐ์ดํฐ์ ์ batch ๋น์ค์ ๋ค๋ฅด๊ฒ ํด์ฃผ์๋๋ฐ unlabel ๋ฐ์ดํฐ์ ์ ๋ ๋ง์ด ํ์ตํ๋๋ก ํด์ฃผ์๊ณ ๊ทธ ๊ฒฐ๊ณผ ๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์ธ๋ค๊ณ ํ๋ค.
๐ noise ๊ฐ ์ค์ํ ์ด์
training์ ํ ๋, student ๋ชจ๋ธ์ ์ต์ข ๋ฐ์ดํฐ์ ์ noise๋ฅผ ์ค (augmentation์ ๊ฐํ) ๋ฐ์ดํฐ์ ์ผ๋ก ํ์ต์ ํ๋๋ฐ, ์ด๋์ noisesms SD(stochastic depth), dropout, data-augmentation ์ด ์ธ ๋ฐฉ์์ ์ฌ์ฉํ๋ค.
์๋ ํ๋ฅผ ๋ถ์ํด๋ณด๋ฉด ๊ธฐ๋ณธ EfficientNet-B5 ๋ชจ๋ธ(๋ ธ๋์)๋ณด๋ค Noisy student training์ ํ ๋ชจ๋ธ(๋นจ๊ฐ์)์ ์ฑ๋ฅ์ด ๋ ๋์๊ฒ์ ํ์ธํ ์ ์๋ค. ๋ํ, ๊ทธ์ ๋นํด augmentation์ ํ์ง ์์ ๋ชจ๋ธ๋ค์(๋นจ๊ฐ์) ์ฑ๋ฅ์ด ๋ ๋ฎ์์ง์ ๋ณผ ์ ์๋ค.
๊ทธ๋ฆฌ๊ณ ํํธ, teacher ๋ชจ๋ธ์ augmentation์ ์ฃผ๋ฉด ๋ฐ๋๋ก ๋ ์ฑ๋ฅ์ด ๋ฎ์์ง๋ ๊ฒ์ ์ ์ ์๋ค. ํ์ง๋ง ์ ์ฒด์ ์ผ๋ก noisy student training ์ ๊ฑฐ์น ๋ชจ๋ธ๋ค์ augmentation ์ฌ๋ถ์ ์๊ด์์ด ์๋ ๋ฒ ์ด์ค efficientNet ๋ชจ๋ธ๋ณด๋ค ๋ ์ข์ ์ฑ๋ฅ์ ๋ณด์ด๋๋ฐ ์ด๊ฒ์ ๋ ผ๋ฌธ์์๋ ์๋ง training ๋ stochastic gradient descent ๋ฅผ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์ ๋ ๋์์ง ๊ฒ์ด๋ผ๊ณ ๊ฐ์ค์ ์ธ์ ๋ค.
๐ Experiments Result
์ด๋ฒ ๋ ผ๋ฌธ์ ์ด์ SOTA ์ฑ๋ฅ ๋ชจ๋ธ๋ณด๋ค ๋ ์ ์ ํ๋ผ๋ฏธํฐ์ extra data ์๋ก ๋ ๋์ ์ฑ๋ฅ์ ๋ฌ์ฑํ๋ค๋ ์ ์์ ์๋ฏธ๊ฐ ์๋ค.
๋๊ธ๋จ๊ธฐ๊ธฐ