[Paper Review 📃] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

-ViT-

๋…ผ๋ฌธ์›๋ณธ๐Ÿ˜™

์˜ค๋Š˜์€ Vision Transformer์— ๋Œ€ํ•ด์„œ ๋ฆฌ๋ทฐํ•ด๋ณด๋„๋ก ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. Vision Transformer, ์ค„์—ฌ์„œ ViT๋Š” ์›๋ž˜ ์ž์—ฐ์–ด ์ฒ˜๋ฆฌ(NLP)์—์„œ ์‚ฌ์šฉํ•˜๋˜ ๋ชจ๋ธ์„ Vision ์ชฝ์— ์ ์šฉํ•œ ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค.

Since the paper builds on the concepts of the Transformer and attention, reading [Paper Review 📃] Attention Is All You Need first will make the Vision Transformer much easier to understand. :)

I'll also link the Attention Is All You Need paper!

🌕 Abstract

NLP ์—์„œ 100B ์˜ ํŒŒ๋ผ๋ฏธํ„ฐ๊ฐ€ ๋„˜๋Š” ์‚ฌ์ด์ฆˆ๋ฅผ ํ•™์Šตํ•  ์ˆ˜ ์žˆ๊ฒŒ ๋œ ๋ฐ˜๋ฉด, ์ปดํ“จํ„ฐ ๋น„์ „์ชฝ์—์„œ๋Š” ์•„์ง CNN ๊ตฌ์กฐ๊ฐ€ ์ฃผ๋ฅผ ์ด๋ฃจ๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋Ÿฐ NLP์˜ ์„ฑ๊ณต์— ์˜ํ–ฅ์„ ๋ฐ›์•„ ์ด๋ฏธ์ง€์— ์ง์ ‘์ ์œผ๋กœ Transformer์„ ์ ์šฉํ•˜๊ณ ์ž ํ•˜์˜€๊ณ , transformer์„ ์ ์šฉํ•˜์—ฌ ViT๋Š” ์—„์ฒญ๋‚œ ํฌ๊ธฐ์˜ ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ํ•™์Šตํ•˜์—ฌ SOTA๋ฅผ ๋‹ฌ์„ฑํ•˜๋Š”๋ฐ ์„ฑ๊ณตํ•ฉ๋‹ˆ๋‹ค.

A quick look at the process:

์ž์—ฐ์–ด์—์„œ ์ž…๋ ฅ ์‹œํ€€์Šค๋ฅผ ํ† ํฐ์œผ๋กœ ์ชผ๊ฐœ์–ด ๊ฐ๊ฐ์„ ์ž…๋ ฅ๊ฐ’์œผ๋กœ ํ•ด์„œ ์ง„ํ–‰ํ•˜๋Š” ๊ฒƒ ์ฒ˜๋Ÿผ vision ์ชฝ์—์„œ๋„ ์ด๋ฏธ์ง€๋ฅผ ํŒจ์น˜ ๋‹จ์œ„๋กœ ์ชผ๊ฐœ์–ด์„œ ๊ฐœ๋ณ„์ ์œผ๋กœ ์ž…๋ ฅ์„ ๋„ฃ์–ด ์ฒ˜๋ฆฌํ•ด ์ง„ํ–‰ํ•˜๊ฒŒ ๋˜๋Š”๋ฐ

On the other hand, applying a Transformer to computer vision runs into the issue of inductive bias. Inductive bias refers to the set of assumptions a model acquires during training and then uses to make predictions when it receives new data it has never seen. That is how I would interpret the term.

CNN์€ locality ํ•œ ์ •๋ณด๋ฅผ ์ถ”์ถœํ•˜๊ณ  convolution์—ฐ์‚ฐ์€ translation equivariance์™€ localityํ•˜๋‹ค๋Š” ํŠน์ง•์„ ๊ฐ–๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค.

๋”ฐ๋ผ์„œ ์ด๋Ÿฌํ•œ ๋ฌธ์ œ๋ฅผ ๊ฐ–๊ณ  ์žˆ๋Š” CNN์€ ์ƒˆ๋กœ์šด ๋ถ„ํฌ์˜ ๋ฐ์ดํ„ฐ๋ฅผ ์ž…๋ ฅ ๋ฐ›์„ ๋•Œ, ํ•ด๋‹น ๋ฐ์ดํ„ฐ์— ๋Œ€ํ•œ ํŒ๋‹จ์„ ๋‚ด๋ฆฌ๊ธฐ ์œ„ํ•ด ํ•™์Šต๊ณผ์ •์—์„œ ์Šต๋“๋œ Bias์ธ โ€œ์ง€์—ญ์  ํŠน์„ฑโ€ ๊ฐ€์ง€๊ณ  ํŒ๋‹จ์„ ํ•˜๊ฒŒ ๋˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค.

Equivariance means that when the input to a function changes, the output changes accordingly; translation equivariance means that if the input is shifted, the output comes out shifted by the same amount.

Locality means that the features making up an image are formed not from the whole image but from nearby pixels in a local region, and that only neighboring pixels are strongly dependent on one another.

In short, the Transformer chops the image into patches and processes them as a flat sequence, ignoring local structure. From the CNN perspective, where each pixel's position and its surrounding values matter, flattening these patches into 1D and processing them in parallel would seem unable to produce the desired output. That seems to be the point.

But! The paper shows that with a huge amount of data, this missing inductive bias can be compensated for. In the end... having lots of data seems to have been the answer. 🐱

image


🌕 Method

image

์œ„์˜ ์ด๋ฏธ์ง€์™€ ๊ฐ™์ด ์ด๋ฏธ์ง€๋ฅผ ์ผ์ •ํ•œ ๊ฐ’์˜ patch๋กœ ๋ถ„ํ• ํ•˜๊ณ  ์ด patch ๋“ค์„ embedding์„ ์‹œ์ผœ์„œ transformer์˜ ์ž…๋ ฅ๊ฐ’์œผ๋กœ ์‚ฌ์šฉํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค. ๊ฐœ๋ณ„์˜ patch๋ฅผ NLP์˜ token ์ฒ˜๋Ÿผ ๊ฐ„์ฃผํ•˜๋Š” ๊ฒƒ์ด์ฃ .

Let's walk through the process in detail.

image

If we take a 48x48 image as the input, it is split into 9 patches, each of shape (16x16) x 3. Each patch is then linearly projected to produce the patch embedding, so each patch becomes a single 768-dimensional vector.

16x16x3 = 768 = D
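This patch-splitting and projection step can be sketched in a few lines of NumPy. This is a minimal illustration of the shapes only: the projection matrix `E` is randomly initialized here, whereas in ViT it is a learned parameter.

```python
import numpy as np

rng = np.random.default_rng(0)

# A 48x48 RGB input, patch size 16 -> a 3x3 grid of 9 patches.
img = rng.standard_normal((48, 48, 3))
P, D = 16, 768  # patch size; D = 16*16*3 = 768 in this example

# Split into non-overlapping PxP patches and flatten each to a vector.
h, w, c = img.shape
patches = (img.reshape(h // P, P, w // P, P, c)
              .swapaxes(1, 2)
              .reshape(-1, P * P * c))  # (9, 768)

# Linear projection E (learned in ViT; random here for illustration).
E = rng.standard_normal((P * P * c, D)) * 0.02
patch_embeddings = patches @ E  # (9, 768): one token per patch
```

With a 16x16 patch, each flattened patch is already 16·16·3 = 768-dimensional, which is why D = 768 works out so neatly in this example.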


image

๊ทธ ๋‹ค์Œ์—๋Š” ๊ฐ๊ฐ์˜ embedding๋œ ํŒจ์น˜์— PE(position embedding) ๋ฅผ ๋”ํ•ด์ค€ ๊ฐ’์„ ์ตœ์ข… transformer์˜ ์ž…๋ ฅ ๋ฒกํ„ฐ๋กœ ์‚ฌ์šฉํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค. ์ด๋•Œ, ๋งจ ์•ž์—๋Š” CLASS TOKEN ์ด๋ผ๋Š” ๊ฒƒ์ด ๋ถ™๋Š”๋ฐ์š”, ์ž์—ฐ์–ด ์ฒ˜๋ฆฌ์˜ BERT ๋ชจ๋ธ์—์„œ ๋‚˜์˜จ ๊ฐœ๋…์„ ์‚ฌ์šฉํ•œ ๊ฒƒ์ž…๋‹ˆ๋‹ค. class token์€ ํ•™์Šต์„ ํ•˜๋ฉด์„œ ๊ฐ ํ† ํฐ๋“ค์— ๋Œ€ํ•œ ์ •๋ณด๊ฐ’์„ ์ €์žฅํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค. ๋”ฐ๋ผ์„œ ์ด 0๋ฒˆ์งธ ํŒจ์น˜๋Š” ๊ฐ ํŒจ์น˜๋“ค์˜ ์œ„์น˜๊ฐ’๊ณผ ์ •๋ณด๋“ค์„ ๊ฐ€์ง€๊ณ  ์žˆ์–ด ์‹ค์ œ๋กœ classification์œผ๋กœ ์‹๋ณ„ ํ•  ๋•Œ, ์ด class token์œผ๋กœ ์ง„ํ–‰ํ•œ๋‹ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค. ๋‚˜๋จธ์ง€ ํŒจ์น˜๋“ค์„ ํ•™์Šต์šฉ์œผ๋กœ, ์ •๋ณด๋ฅผ ๋น„๊ตํ• ๋•Œ๋Š” ๋งจ ์•ž์˜ ํŒจ์น˜(class token)์œผ๋กœ!


A quick aside! 🖐 Let's note the difference between the original Transformer and the Transformer used in vision before moving on.

image

๋‘˜์„ ๋น„๊ตํ–ˆ์„ ๋•Œ! Normalization ์˜ ์œ„์น˜๊ฐ€ ๋‹ค๋ฅธ๊ฒƒ์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. Layer normalization์˜ ์œ„์น˜๊ฐ€ ์ค‘์š”ํ•˜๋‹ค๋Š” ๊ฒƒ์ด ํ›„์† ์—ฐ๊ตฌ๋“ค์—์„œ ๋ฐํ˜€์ง€๊ฒŒ ๋˜๋ฉด์„œ Attention๊ณผ MLP์ด์ „์— ์ •๊ทœํ™”๋ฅผ ์ˆ˜ํ–‰ํ•˜์—ฌ ๋” ๊นŠ์€ ๋ ˆ์ด์–ด์—์„œ๋„ ํ•™์Šต์ด ์ž˜ ๋˜๋„๋ก ํ–ˆ๋‹ค๊ณ  ํ•ฉ๋‹ˆ๋‹ค.


๐ŸŒ Fine-tuning and higher resolution

Rather than using ViT as-is, it is first pre-trained on a large dataset and then fine-tuned. During pre-training, an MLP with one hidden layer is attached to the Transformer encoder output; at fine-tuning time, this is replaced with a single randomly initialized linear layer, and the position embeddings learned in pre-training must be adjusted to the different resolution of the fine-tuning images.

At the start of training, the position embeddings provide no information about the 2D position of the patches, and spatial relations between patches must be learned from scratch. When fine-tuning on higher-resolution images, the paper applies 2D interpolation to the pre-trained position embeddings, the one point where a 2D bias is injected by hand.
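One way to realize this resizing is a bilinear interpolation over the embedding grid. The sketch below is a pure-NumPy version under my own assumptions: `resize_pos_embed` is a hypothetical helper name, and real implementations typically use a library interpolation routine instead.

```python
import numpy as np

def resize_pos_embed(pos_embed, old_hw, new_hw):
    """Bilinearly interpolate a grid of patch position embeddings.

    pos_embed: (old_h*old_w, D) learned patch position embeddings
    (the class-token embedding, if any, is handled separately).
    """
    old_h, old_w = old_hw
    new_h, new_w = new_hw
    d = pos_embed.shape[1]
    grid = pos_embed.reshape(old_h, old_w, d)

    # Fractional coordinates in the old grid for each new grid cell.
    ys = np.linspace(0, old_h - 1, new_h)
    xs = np.linspace(0, old_w - 1, new_w)
    y0 = np.floor(ys).astype(int); y1 = np.minimum(y0 + 1, old_h - 1)
    x0 = np.floor(xs).astype(int); x1 = np.minimum(x0 + 1, old_w - 1)
    wy = (ys - y0)[:, None, None]
    wx = (xs - x0)[None, :, None]

    # Blend the four surrounding embeddings for each new position.
    top = grid[y0][:, x0] * (1 - wx) + grid[y0][:, x1] * wx
    bot = grid[y1][:, x0] * (1 - wx) + grid[y1][:, x1] * wx
    out = top * (1 - wy) + bot * wy
    return out.reshape(new_h * new_w, d)
```

For example, fine-tuning a model pre-trained with a 14x14 patch grid on images that yield a 24x24 grid would call `resize_pos_embed(pe, (14, 14), (24, 24))`.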

๐ŸŒ hybrid architecture

The paper also introduces a hybrid architecture that uses a CNN together with the Transformer. Patches are extracted from a CNN feature map and patch embedding is applied to them: the CNN extracts the local information, and the sequence of patches from these features is embedded and passed to the Transformer. Its advantage is that performance saturates with a smaller dataset than ViT needs; with a large dataset, however, it falls behind ViT.
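The hybrid input pipeline can be sketched as below. The feature-map shape and channel count are made-up numbers for illustration, and the CNN itself is replaced by a random array; in the hybrid model the "patches" can be as small as 1x1, so each spatial position of the feature map becomes one token.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for a CNN backbone output: a 14x14 feature map, 256 channels.
feature_map = rng.standard_normal((14, 14, 256))

# Treat each spatial position as a 1x1 "patch", then project to D dims.
D = 768
tokens = feature_map.reshape(-1, 256)       # (196, 256)
E = rng.standard_normal((256, D)) * 0.02    # learned in the real model
hybrid_embeddings = tokens @ E              # (196, 768) -> into the Transformer
```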

Below is a graph showing these results.

image


🌕 Experiments

Below, the authors pre-train on datasets of various sizes and confirm SOTA performance. Information about the JFT dataset is here!

image

Looking at ImageNet Top-1 accuracy, you can see that the larger the pre-training dataset, the better the performance.

image

Also, training with different sizes of the JFT dataset shows, unsurprisingly, that performance improves as the dataset grows.

image


