[๋ ผ๋ฌธ์ ๋ฆฌ๐] An Image is Worth 16X16 Words : Transformers for Image Recognition at Scale
An Image is Worth 16X16 Words : Transformers for Image Recognition at Scale
-ViT-
๋ ผ๋ฌธ์๋ณธ๐
์ค๋์ Vision Transformer์ ๋ํด์ ๋ฆฌ๋ทฐํด๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค. Vision Transformer, ์ค์ฌ์ ViT๋ ์๋ ์์ฐ์ด ์ฒ๋ฆฌ(NLP)์์ ์ฌ์ฉํ๋ ๋ชจ๋ธ์ Vision ์ชฝ์ ์ ์ฉํ ๋ชจ๋ธ์ ๋๋ค.
transformer
๊ณผ attention
์ ๋ํ ๊ฐ๋
์ด ๋์ค๋๋ฐ, [๋
ผ๋ฌธ์ ๋ฆฌ๐]Attention is all you need ๋ฅผ ๋จผ์ ์ฝ์ด๋ณด๋ฉด Vision transformer์ ๋ํด ์ดํดํ๊ธฐ ์์ ํ ๊ฒ์
๋๋ค. :)
Attention is all you need ๋ ผ๋ฌธ ๋งํฌ๋ ๊ฑธ์ด๋๊ฒ ์ต๋๋ค!
๐ Abstract
NLP ์์ 100B ์ ํ๋ผ๋ฏธํฐ๊ฐ ๋๋ ์ฌ์ด์ฆ๋ฅผ ํ์ตํ ์ ์๊ฒ ๋ ๋ฐ๋ฉด, ์ปดํจํฐ ๋น์ ์ชฝ์์๋ ์์ง CNN ๊ตฌ์กฐ๊ฐ ์ฃผ๋ฅผ ์ด๋ฃจ๊ณ ์์ต๋๋ค. ์ด๋ฐ NLP์ ์ฑ๊ณต์ ์ํฅ์ ๋ฐ์ ์ด๋ฏธ์ง์ ์ง์ ์ ์ผ๋ก Transformer์ ์ ์ฉํ๊ณ ์ ํ์๊ณ , transformer์ ์ ์ฉํ์ฌ ViT๋ ์์ฒญ๋ ํฌ๊ธฐ์ ๋ฐ์ดํฐ์ ์ผ๋ก ํ์ตํ์ฌ SOTA๋ฅผ ๋ฌ์ฑํ๋๋ฐ ์ฑ๊ณตํฉ๋๋ค.
ํ๋ก์ธ์ค๋ฅผ ๊ฐ๋จํ ์ดํด๋ณด๋ฉด,
์์ฐ์ด์์ ์ ๋ ฅ ์ํ์ค๋ฅผ ํ ํฐ์ผ๋ก ์ชผ๊ฐ์ด ๊ฐ๊ฐ์ ์ ๋ ฅ๊ฐ์ผ๋ก ํด์ ์งํํ๋ ๊ฒ ์ฒ๋ผ vision ์ชฝ์์๋ ์ด๋ฏธ์ง๋ฅผ ํจ์น ๋จ์๋ก ์ชผ๊ฐ์ด์ ๊ฐ๋ณ์ ์ผ๋ก ์ ๋ ฅ์ ๋ฃ์ด ์ฒ๋ฆฌํด ์งํํ๊ฒ ๋๋๋ฐ
ํํธ, transformer์ ์ปดํจํฐ ๋น์ ์ ์ ์ฉํ๊ธฐ์๋ inductive bias
๋ฅผ ๊ฐ๊ณ ์์ต๋๋ค. inductive bias๋ ํ์ต์์ ์์๋ ์๋ก์ด ๋ฐ์ดํฐ๊ฐ ์
๋ ฅ์ผ๋ก ๋ค์ด์ฌ ๋, ํด๋น ๋ฐ์ดํฐ์ ๋ํ ํ๋จ์ ๋ด๋ฆฌ๊ธฐ ์ํด ํ์ต๊ณผ์ ์์ ์ต๋๋ Bias๋ฅผ ๊ฐ์ง๊ณ ํ๋จ์ ํ๊ฒ ๋ฉ๋๋ค. ์ด๊ฒ์ inductive bias
๋ผ๊ณ ํด์ํ๋ฉด ๋ ๊ฒ ๊ฐ์ต๋๋ค.
CNN์ locality ํ ์ ๋ณด๋ฅผ ์ถ์ถํ๊ณ convolution์ฐ์ฐ์ translation equivariance์ localityํ๋ค๋ ํน์ง์ ๊ฐ๊ณ ์์ต๋๋ค.
๋ฐ๋ผ์ ์ด๋ฌํ ๋ฌธ์ ๋ฅผ ๊ฐ๊ณ ์๋ CNN์ ์๋ก์ด ๋ถํฌ์ ๋ฐ์ดํฐ๋ฅผ ์ ๋ ฅ ๋ฐ์ ๋, ํด๋น ๋ฐ์ดํฐ์ ๋ํ ํ๋จ์ ๋ด๋ฆฌ๊ธฐ ์ํด ํ์ต๊ณผ์ ์์ ์ต๋๋ Bias์ธ โ์ง์ญ์ ํน์ฑโ ๊ฐ์ง๊ณ ํ๋จ์ ํ๊ฒ ๋๋ ๊ฒ์ ๋๋ค.
equivariance ๋, ํจ์ ์ ๋ ฅ์ด ๋ฐ๋๋ฉด ์ถ๋ ฅ ๋ํ ๋ฐ๋๋ค๋ ๋ป์ด๋ฉฐ, translation equivariance๋ ์ ๋ ฅ ์์น๊ฐ ๋ณํ๋ฉด ์ถ๋ ฅ๋ ๋ง์ฐฌ๊ฐ์ง๋ก ์์น๊ฐ ๋ณํ์ฑ๋ก ๋์จ๋ค๋ ๋ป์ ๋๋ค.
locality ๋, ์ด๋ฏธ์ง๋ฅผ ๊ตฌ์ฑํ๋ ํน์ง๋ค์ ์ด๋ฏธ์ง ์ ์ฒด๊ฐ ์๋ ์ผ๋ถ ์ง์ญ์ ๊ทผ์ ํ ํฝ์ ๋ค๋ก๋ง ๊ตฌ์ฑ๋๊ณ , ๊ทผ์ ํ ํฝ์ ๋ค๋ผ๋ฆฌ๋ง ์ข ์์ฑ์ ๊ฐ์ง๋ค๋ ์ฑ์ง์ ๋๋ค.
๊ฒฐ๊ตญ, transformer์ ์ด๋ฏธ์ง๋ฅผ ํจ์น๋ก ์ชผ๊ฐ๊ณ ์ผ๋ ฌ๋ก ์ธ์์ ํ์ต(์ง์ญ์ ํน์ง ๋ฌด์)ํ๋๋ฐ, ๊ฐ ํฝ์ ๋ค์ ์์น์ ์ฃผ๋ณ ๊ฐ๋ค์ด ์ค์ํ CNN์ ์์ด์, ์ด ํจ์น๋ค์ ์ ๋ ฅ์ด 1์ฐจ์์ ๊ฐ์ผ๋ก ๋ณ๋ ฌ ์ฒ๋ฆฌ๊ฐ ๋๋ค๋ฉด ์ํ๋ ์ถ๋ ฅ์ด ๋์ฌ ์ ์๊ฒ ๋๋ค๋ ๋ป์ธ ๊ฒ ๊ฐ์ต๋๋ค.
ํ์ง๋ง!, ์ด ๋
ผ๋ฌธ์์๋ ์์ฒญ๋ ์์ ๋ฐ์ดํฐ
๋ฅผ ์ฌ์ฉํด ์ด inductive bias๋ฅผ ์์จ ์ ์์๋ค๊ณ ๋งํฉ๋๋ค. ๊ฒฐ๊ตญ.. ๋ฐ์ดํฐ๊ฐ ๋ง์๊ฒ ํ์ ์๋๋ด
๋๋ค.๐ฑ
๐ Method
์์ ์ด๋ฏธ์ง์ ๊ฐ์ด ์ด๋ฏธ์ง๋ฅผ ์ผ์ ํ ๊ฐ์ patch๋ก ๋ถํ ํ๊ณ ์ด patch ๋ค์ embedding์ ์์ผ์ transformer์ ์ ๋ ฅ๊ฐ์ผ๋ก ์ฌ์ฉํ๊ฒ ๋ฉ๋๋ค. ๊ฐ๋ณ์ patch๋ฅผ NLP์ token ์ฒ๋ผ ๊ฐ์ฃผํ๋ ๊ฒ์ด์ฃ .
ํ๋ก์ธ์ค๋ฅผ ์์ธํ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
๋จผ์ 48x48์ ์ด๋ฏธ์ง๋ฅผ input image๋ก ํ๋ค๋ฉด, (16x16) * 3 ์ ํํ๋ก ์ด 9๊ฐ์ patch ํํ๋ก ์ชผ๊ฐญ๋๋ค. ๊ทธ๋ฆฌ๊ณ ๊ฐ๊ฐ์ patch๋ค์ linear ํ๊ฒ prjection ์์ผ์ patch embedding์ ์งํํฉ๋๋ค. ๋ฐ๋ผ์ 768์ฐจ์์ ํ๋์ ๋ฒกํฐ๊ฐ ๋๋ ๊ฒ์ ๋๋ค.
16x16x3 = 768 = D
๊ทธ ๋ค์์๋ ๊ฐ๊ฐ์ embedding๋ ํจ์น์ PE(position embedding) ๋ฅผ ๋ํด์ค ๊ฐ์ ์ต์ข
transformer์ ์
๋ ฅ ๋ฒกํฐ๋ก ์ฌ์ฉํ๊ฒ ๋ฉ๋๋ค.
์ด๋, ๋งจ ์์๋ CLASS TOKEN ์ด๋ผ๋ ๊ฒ์ด ๋ถ๋๋ฐ์, ์์ฐ์ด ์ฒ๋ฆฌ์ BERT ๋ชจ๋ธ์์ ๋์จ ๊ฐ๋
์ ์ฌ์ฉํ ๊ฒ์
๋๋ค. class token์ ํ์ต์ ํ๋ฉด์ ๊ฐ ํ ํฐ๋ค์ ๋ํ ์ ๋ณด๊ฐ์ ์ ์ฅํ๊ฒ ๋ฉ๋๋ค.
๋ฐ๋ผ์ ์ด 0๋ฒ์งธ ํจ์น๋ ๊ฐ ํจ์น๋ค์ ์์น๊ฐ๊ณผ ์ ๋ณด๋ค์ ๊ฐ์ง๊ณ ์์ด ์ค์ ๋ก classification์ผ๋ก ์๋ณ ํ ๋, ์ด class token์ผ๋ก ์งํ
ํ๋ค๊ณ ํฉ๋๋ค. ๋๋จธ์ง ํจ์น๋ค์ ํ์ต์ฉ์ผ๋ก, ์ ๋ณด๋ฅผ ๋น๊ตํ ๋๋ ๋งจ ์์ ํจ์น(class token)์ผ๋ก!
์ฌ๊ธฐ์ ์ ๊น!๐ ๊ธฐ์กด์ transformer ๊ณผ vision์์ ์ฌ์ฉํ๋ transformer์ ์ฐจ์ด์ ๋ํด ์ง๊ณ ๋์ด๊ฐ๊ฒ ์ต๋๋ค.
๋์ ๋น๊ตํ์ ๋! Normalization ์ ์์น๊ฐ ๋ค๋ฅธ๊ฒ์ ํ์ธํ ์ ์์ต๋๋ค. Layer normalization์ ์์น๊ฐ ์ค์ํ๋ค๋ ๊ฒ์ด ํ์ ์ฐ๊ตฌ๋ค์์ ๋ฐํ์ง๊ฒ ๋๋ฉด์ Attention๊ณผ MLP์ด์ ์ ์ ๊ทํ๋ฅผ ์ํํ์ฌ ๋ ๊น์ ๋ ์ด์ด์์๋ ํ์ต์ด ์ ๋๋๋ก ํ๋ค๊ณ ํฉ๋๋ค.
๐ Fine-tuning and higher resolution
ViT๋ ๊ทธ๋๋ก ์ฌ์ฉํ๊ธฐ ๋ณด๋ค, ๋จผ์ ํฐ ๋ฐ์ดํฐ์ ์ผ๋ก pre-train ์ ์ํจ ๋ชจ๋ธ์ fine-tune ํ์ฌ ์ฌ์ฉํ๋ค๊ณ ํฉ๋๋ค. ๋ฐ๋ผ์, transformer encoder ์ถ๋ ฅ๊ฐ์ ํ๋์ hidden layer๋ฅผ ๊ฐ์ง MLP๋ฅผ ์ฌ์ฉํ์ฌ pre-train ํ๊ฒ ๋ฉ๋๋ค. fine-tunning ์์๋ ๋๋ค ์ด๊ธฐํ๋ ํ๋์ linear layer๋ฅผ ์ฌ์ฉํ๋ฉฐ pre-train์์ ํ์ต๋ positional embedding์ fine-tunning์์ ๋ค๋ฅธ ํด์๋๋ก ์ด๋ฏธ์ง์ ๋ํ position embedding์ ์กฐ์ ํด์ผ ํฉ๋๋ค.
ํ์ต ์ด๊ธฐ์ position embedding์ patch์ 2D ์์น์ ๋ํ ์๋ฌด๋ฐ ์ ๋ณด๋ฅผ ์ ๊ณตํ์ง ์๊ณ , patch ์ฌ์ด์ ๊ณต๊ฐ์ ์ธ ๊ด๊ณ๋ ์ฒ์๋ถํฐ ํ์ต๋์ด์ผ ํฉ๋๋ค. ๋ ผ๋ฌธ์์๋ ๊ณ ํด์๋ ์ด๋ฏธ์ง๋ก fine tunning ํ ๋๋, ํ์ต๋ pre-trained position embedding์ 2D interpolationํ๋ ๋ฐฉ๋ฒ(์๋์ผ๋ก bias๋ฅผ ๋ฃ๋ ๋ฐฉ๋ฒ)์ ์ฌ์ฉํฉ๋๋ค.
๐ hybrid architecture
CNN๊ณผ transformer๋ฅผ ํจ๊ป ์ฌ์ฉํ๋ hybrid archtecture๋ ์๊ฐํ๊ณ ์์ต๋๋ค. CNN์ feature map์ผ๋ก๋ถํฐ patch๋ฅผ ์ถ์ถํ์ฌ patch embedding์ ์ ์ฉํ๋ ํ๋ก์ธ์ค๋ฅผ ๊ฑฐ์น๊ฒ ๋๋๋ฐ, CNN์ผ๋ก localtyํ ์ ๋ณด๋ฅผ ์ถ์ถํ๊ณ , ์ด ์ ๋ณด๋ค์ patch์ sequence๋ฅผ embeddingํ์ฌ transformer๋ก ์ ๋ฌํ๋ ๊ฒ์ ๋๋ค. VIT์ ๋นํด ์ข ๋ ์ ์ dataset์ผ๋ก ์ฑ๋ฅ์ด saturateํ๋ค๋ ์ฅ์ ์ด ์์ต๋๋ค. ํ์ง๋ง ํฐ ๋ฐ์ดํฐ์ ์ ์ฌ์ฉํ๋ค๋ฉด VIT๋ณด๋ค ์ฑ๋ฅ์ด ๋ค๋จ์ด ์ง๋๋ค.
์๋๋ ๊ทธ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ฃผ๋ ๊ทธ๋ํ ์ ๋๋ค.
๐ Experiments
์๋๋ ๋ค์ํ ํฌ๊ธฐ์ ๋ฐ์ดํฐ์ ์ ๋ํด pre-train์ ์งํํ๊ณ SOTA ์ฑ๋ฅ์ ๋ณด์ด๋ ๊ฒ์ ํ์ธ์์ผ์ฃผ๊ณ ์์ต๋๋ค. JFT ๋ฐ์ดํฐ์ ์ ๋ํ ์ ๋ณด๋ ์ฌ๊ธฐ!
ImageNet Top-1 accuracy ์ฑ๋ฅ์ ๋ณด์๋ pretraining dataset์ ํฌ๊ธฐ๊ฐ ํด ์๋ก ์ฑ๋ฅ์ด ๋ ์ข์์ง๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค.
๋ํ, JFT ๋ฐ์ดํฐ์ ํฌ๊ธฐ๋ฅผ ๋ค๋ฅด๊ฒ ์ฃผ์ด ํ์ตํ ๊ฒฐ๊ณผ ์ญ์๋ ๋ฐ์ดํฐ์ ์ด ํด ์๋ก ์ฑ๋ฅ์ด ์ข์์ง๋๋ค.
์ฐธ๊ณ
[1] inductive bias
๋๊ธ๋จ๊ธฐ๊ธฐ