[CODE]Convolutional relation network for facial expression recognition in the wild with few-shot learning
Convolutional relation network for facial expression recognition in the wild with few-shot learningPermalink
- CRN -Permalink
논문원본😙
오늘은 얼굴 이미지 데이터를 few-shot 학습기법으로 학습하여 얼굴데이터셋에 적절한 loss function을 고안한 논문의 코드를 구현해 보았다.
DAP 로 쿼리 이미지들의 feature 값을 pooling 해준 값으로 모든 로스를 계산한다.
-
첫번째 로스는(loss_d) JS-Divergence 를 이용해서 feature embedding 에서 뽑은 값으로 query batch 와 DAP를 계산해준 sample batch(support) 간의 mse loss 를 구해준 후
-
두번째 로스는(loss_r) query batch 와 DAP를 계산해준 sample batch(support) 를 relation network 를 통해 뽑은 feature 끼리의 mse 로 계산된다.
-
최종으로 사용하는 로스는 loss_d + lambda(loss_r) 을 전체 샘플 갯수만큼 나눠준 값으로 모델 값을 update 해준다.
구현한 코드는 다음과 같다.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def mean_confidence_interval(data, confidence=0.95):
    """Return (mean, half-width) of the *confidence* interval for *data*.

    Uses the Student t distribution, so it is valid for small samples.

    Args:
        data: sequence of numeric observations (e.g. per-episode accuracies).
        confidence: two-sided confidence level, default 0.95.

    Returns:
        Tuple ``(m, h)``: the sample mean and the interval half-width, so the
        interval is ``[m - h, m + h]``.
    """
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    # BUG FIX: the original called the private ``sp.stats.t._ppf`` (mixing the
    # ``sp``/``scipy`` aliases); ``scipy.stats.t.ppf`` is the public inverse CDF.
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
    return m, h
class CNNEncoder(nn.Module):
    """Four-layer conv backbone producing a 64-channel spatial feature map.

    For an 84x84 RGB input the output is (N, 64, 19, 19). The map is kept
    spatial (not flattened) so support/query maps can be concatenated
    channel-wise for the relation network.
    """

    def __init__(self):
        super(CNNEncoder, self).__init__()

        def conv_block(in_ch, pad, pooled):
            # Conv -> BN -> ReLU, optionally followed by 2x2 max-pooling.
            mods = [
                nn.Conv2d(in_ch, 64, kernel_size=3, padding=pad),
                nn.BatchNorm2d(64, momentum=1, affine=True),
                nn.ReLU(),
            ]
            if pooled:
                mods.append(nn.MaxPool2d(2))
            return nn.Sequential(*mods)

        # Two pooled, unpadded blocks shrink the resolution; two padded
        # blocks keep it. Attribute names match the original checkpoints.
        self.layer1 = conv_block(3, 0, True)
        self.layer2 = conv_block(64, 0, True)
        self.layer3 = conv_block(64, 1, False)
        self.layer4 = conv_block(64, 1, False)

    def forward(self, x):
        feat = x
        for block in (self.layer1, self.layer2, self.layer3, self.layer4):
            feat = block(feat)
        return feat  # (N, 64, H', W') spatial feature map
class CNNEncoder_LOSS(nn.Module):
    """Same conv backbone as CNNEncoder, but returns a flattened vector.

    For an 84x84 RGB input the output is (N, 64*19*19) = (N, 23104). This
    flat embedding feeds the distribution-alignment (JSD) loss branch.
    """

    def __init__(self):
        super(CNNEncoder_LOSS, self).__init__()
        # (in_channels, padding, has_pooling) for each of the four blocks.
        specs = [(3, 0, True), (64, 0, True), (64, 1, False), (64, 1, False)]
        for idx, (in_ch, pad, pooled) in enumerate(specs, start=1):
            mods = [
                nn.Conv2d(in_ch, 64, kernel_size=3, padding=pad),
                nn.BatchNorm2d(64, momentum=1, affine=True),
                nn.ReLU(),
            ]
            if pooled:
                mods.append(nn.MaxPool2d(2))
            # Attribute names layer1..layer4 preserved for checkpoint loading.
            setattr(self, 'layer%d' % idx, nn.Sequential(*mods))

    def forward(self, x):
        h = self.layer2(self.layer1(x))
        h = self.layer4(self.layer3(h))
        # Flatten everything but the batch dimension.
        return h.view(h.size(0), -1)
class RelationNetwork(nn.Module):
    """Scores a concatenated (support, query) feature pair with a value in (0, 1).

    Args:
        input_size: channel count of each encoder feature map (64 here); after
            the two pooled conv blocks reduce 19x19 to 3x3, the first FC layer
            has ``input_size*3*3`` inputs.
        hidden_size: width of the hidden fully-connected layer.

    Input:  (N, 128, 19, 19) concatenated feature pairs.
    Output: (N, 1) sigmoid relation scores.
    """

    def __init__(self, input_size, hidden_size):
        super(RelationNetwork, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(64 * 2, 64, kernel_size=3, padding=0),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=0),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.fc1 = nn.Linear(input_size * 3 * 3, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        # FIX: torch.sigmoid replaces the long-deprecated F.sigmoid.
        out = torch.sigmoid(self.fc2(out))
        return out
def weights_init(m):
    """Layer-type-aware initializer, intended for use with ``Module.apply``.

    Conv*: He-style normal weights (std = sqrt(2 / fan_out)), zero bias.
    BatchNorm*: unit weight, zero bias.
    Linear: N(0, 0.01) weights, all-ones bias.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        # fan_out = kernel area * output channels (He initialization).
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / fan_out))
        if m.bias is not None:
            m.bias.data.zero_()
    elif 'BatchNorm' in layer_kind:
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif 'Linear' in layer_kind:
        m.weight.data.normal_(0, 0.01)
        m.bias.data = torch.ones(m.bias.data.size())
def JSD(P, Q):
    """Jensen-Shannon divergence between probability tensors P and Q.

    JSD(P, Q) = 0.5*KL(P||M) + 0.5*KL(Q||M) with M = (P + Q) / 2.

    Bug fixes vs. the original:
    * ``KLDivLoss`` expects its first argument in log-space and computes
      KL(target || input); the original passed raw probabilities in the
      wrong order, so it did not compute a divergence at all.
    * ``reduction='sum'`` yields the actual divergence rather than an
      element-wise mean.
    * The hard-coded ``.cuda()`` is dropped: the loss module is stateless,
      so it works on whichever device P and Q already live on.
    """
    kld = KLDivLoss(reduction='sum')
    M = 0.5 * (P + Q)
    log_M = M.log()
    return 0.5 * (kld(log_M, P) + kld(log_M, Q))
def main():
    """Meta-train the CRN: DAP-pooled support prototypes are compared with
    queries via a JSD-based loss (loss_d) and a relation-score loss (loss_r);
    the total loss is loss_d + lambda * loss_r.

    Bug fixes vs. the original:
    * ``feature_encoder_l_optim`` optimized ``feature_encoder.parameters()``
      and its scheduler wrapped the wrong optimizer, so ``feature_encoder_l``
      was never trained; its grads were also never cleared or applied.
    * Python-2-only ``loader.__iter__().next()`` replaced by ``next(iter(...))``.
    * Deprecated ``clip_grad_norm`` replaced by ``clip_grad_norm_``.
    * Inner test loop no longer shadows the outer loop index.
    """
    # Step 1: init data folders
    print("init data folders")
    metatrain_folders, metatest_folders = tg.mini_imagenet_folders()

    # Step 2: init neural networks
    print("init neural networks")
    feature_encoder = CNNEncoder()         # spatial features for the relation branch
    feature_encoder_l = CNNEncoder_LOSS()  # flat features for the JSD loss branch
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)

    feature_encoder.apply(weights_init)
    feature_encoder_l.apply(weights_init)
    relation_network.apply(weights_init)

    feature_encoder.cuda(GPU)
    feature_encoder_l.cuda(GPU)
    relation_network.cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(), lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim, step_size=100000, gamma=0.5)
    # BUG FIX: was Adam(feature_encoder.parameters()) / StepLR(feature_encoder_optim).
    feature_encoder_l_optim = torch.optim.Adam(feature_encoder_l.parameters(), lr=LEARNING_RATE)
    feature_encoder_l_scheduler = StepLR(feature_encoder_l_optim, step_size=100000, gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(), lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim, step_size=100000, gamma=0.5)

    # Checkpoint paths (byte-identical to the strings the original built inline).
    encoder_ckpt = "./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
    relation_ckpt = "./models/miniimagenet_relation_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
    if os.path.exists(encoder_ckpt):
        # Both encoders share one checkpoint; CNNEncoder_LOSS has the same
        # layer1..layer4 structure, so the state_dict loads into both.
        feature_encoder.load_state_dict(torch.load(encoder_ckpt))
        feature_encoder_l.load_state_dict(torch.load(encoder_ckpt))
        print("load feature encoder success")
    if os.path.exists(relation_ckpt):
        relation_network.load_state_dict(torch.load(relation_ckpt))
        print("load relation network success")

    # Step 3: episodic training
    print("Training...")
    last_accuracy = 0.0
    for episode in range(EPISODE):
        feature_encoder_scheduler.step(episode)
        feature_encoder_l_scheduler.step(episode)
        relation_network_scheduler.step(episode)

        # sample_dataloader yields the support set; batch_dataloader the queries.
        task = tg.MiniImagenetTask(metatrain_folders, CLASS_NUM, SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        sample_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train", shuffle=False)
        batch_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=BATCH_NUM_PER_CLASS, split="test", shuffle=True)

        # BUG FIX: ``.__iter__().next()`` is Python 2 only.
        samples, sample_labels = next(iter(sample_dataloader))  # 25x3x84x84
        batches, batch_labels = next(iter(batch_dataloader))

        # Support features from both encoders.
        sample_features = feature_encoder(Variable(samples).cuda(GPU))  # 25x64x19x19
        sample_features = sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 19, 19)
        sample_features_l = feature_encoder_l(Variable(samples).cuda(GPU))
        sample_features_l = sample_features_l.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, 23104)

        # Per-class stacks (the original loop just re-stacked along dim 0).
        each_class_feature = torch.stack([sample_features[c] for c in range(CLASS_NUM)], dim=0)
        each_class_feature = torch.transpose(each_class_feature, 1, 2)
        each_class_feature_l = torch.stack([sample_features_l[c] for c in range(CLASS_NUM)], dim=0)

        # DAP: average-pool over the shot dimension of the support features.
        DAP = nn.AdaptiveAvgPool3d((1, None, None))(each_class_feature).squeeze(2)
        DAP_L = nn.AdaptiveAvgPool2d((1, None))(each_class_feature_l).squeeze(1)  # (CLASS_NUM, 23104)

        # Query features.
        batch_features = feature_encoder(Variable(batches).cuda(GPU))
        batch_features_l = feature_encoder_l(Variable(batches).cuda(GPU))

        # Pair every query with every class prototype.
        sample_features_ext = DAP.unsqueeze(0).repeat(BATCH_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
        batch_features_ext = batch_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1)
        sample_features_ext_l = DAP_L.unsqueeze(0).repeat(BATCH_NUM_PER_CLASS * CLASS_NUM, 1, 1)
        batch_features_ext_l = batch_features_l.unsqueeze(0).repeat(CLASS_NUM, 1, 1)
        batch_features_ext = torch.transpose(batch_features_ext, 0, 1)  # swap pair dims

        mse = nn.MSELoss().cuda(GPU)

        # loss_d: D_JS(P(f_theta(s_k^(i))), P(f_theta(q_k^(j)))) between the
        # per-feature means of the DAP support and query embeddings.
        mean_sample_each_class = sample_features_ext_l.transpose(0, 1).mean(dim=2)
        mean_batch_each_class = batch_features_ext_l.mean(dim=2)
        every_sample_loss = [
            JSD(mean_sample_each_class[c][q], mean_batch_each_class[c][q])
            for c in range(CLASS_NUM)
            for q in range(BATCH_NUM_PER_CLASS * CLASS_NUM)
        ]
        loss_each_class = torch.stack(every_sample_loss, dim=0).view(-1, CLASS_NUM)
        one_hot_labels = Variable(torch.zeros(BATCH_NUM_PER_CLASS * CLASS_NUM, CLASS_NUM).scatter_(1, batch_labels.view(-1, 1), 1).cuda(GPU))
        loss_d = mse(loss_each_class, one_hot_labels)

        # loss_r: relation scores of concatenated (support, query) pairs.
        # cat along dim 2 merges the channel axis: 50x5x128x19x19 -> 250x128x19x19.
        relation_pairs = torch.cat((sample_features_ext, batch_features_ext), 2).view(-1, FEATURE_DIM * 2, 19, 19)
        relations = relation_network(relation_pairs).view(-1, CLASS_NUM)
        loss_r = mse(relations, one_hot_labels)

        lam = 0.5
        loss = loss_d + lam * loss_r

        # Training step. BUG FIX: feature_encoder_l now also gets its grads
        # cleared, clipped, and applied.
        feature_encoder.zero_grad()
        feature_encoder_l.zero_grad()
        relation_network.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(feature_encoder_l.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)
        feature_encoder_optim.step()
        feature_encoder_l_optim.step()
        relation_network_optim.step()

        if (episode + 1) % 100 == 0:
            print("episode:", episode + 1, "loss", loss.item())

        if episode % 5000 == 0:
            # Periodic evaluation on the meta-test split.
            print("Testing...")
            accuracies = []
            for _ in range(TEST_EPISODE):  # BUG FIX: no longer shadows inner index
                total_rewards = 0
                task = tg.MiniImagenetTask(metatest_folders, CLASS_NUM, SAMPLE_NUM_PER_CLASS, 15)
                sample_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train", shuffle=False)
                num_per_class = 5
                test_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=num_per_class, split="test", shuffle=False)
                sample_images, sample_labels = next(iter(sample_dataloader))
                for test_images, test_labels in test_dataloader:
                    batch_size = test_labels.shape[0]
                    # DAP prototypes from the support set, as in training.
                    sample_features = feature_encoder(Variable(sample_images).cuda(GPU))
                    sample_features = sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 19, 19)
                    each_class_feature = torch.stack([sample_features[c] for c in range(CLASS_NUM)], dim=0)
                    each_class_feature = torch.transpose(each_class_feature, 1, 2)
                    DAP = nn.AdaptiveAvgPool3d((1, None, None))(each_class_feature).squeeze(2)
                    test_features = feature_encoder(Variable(test_images).cuda(GPU))
                    sample_features_ext = DAP.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
                    test_features_ext = test_features.unsqueeze(0).repeat(1 * CLASS_NUM, 1, 1, 1, 1)
                    test_features_ext = torch.transpose(test_features_ext, 0, 1)
                    relation_pairs = torch.cat((sample_features_ext, test_features_ext), 2).view(-1, FEATURE_DIM * 2, 19, 19)
                    relations = relation_network(relation_pairs).view(-1, CLASS_NUM)
                    _, predict_labels = torch.max(relations.data, 1)
                    rewards = [1 if predict_labels[j] == test_labels[j] else 0 for j in range(batch_size)]
                    total_rewards += np.sum(rewards)
                # 15 query images per class were requested above.
                accuracy = total_rewards / 1.0 / CLASS_NUM / 15
                accuracies.append(accuracy)
            test_accuracy, h = mean_confidence_interval(accuracies)
            print("test accuracy:", test_accuracy, "h:", h)
            if test_accuracy > last_accuracy:
                # Save only when the meta-test accuracy improves.
                torch.save(feature_encoder.state_dict(), encoder_ckpt)
                torch.save(relation_network.state_dict(), relation_ckpt)
                print("save networks for episode:", episode)
                last_accuracy = test_accuracy


if __name__ == '__main__':
    main()
댓글남기기