当前位置:  开发笔记 > 开发工具 > 正文

facenet triplet loss with keras

如何解决《facenettripletlosswithkeras》经验,为你挑选了2个好方法。

I am trying to implement facenet in Keras with Thensorflow backend and I have some problem with the triplet loss.在此输入图像描述

I call the fit function with 3*n number of images and then I define my custom loss function as follows:

def triplet_loss(self, y_true, y_pred):

    embeddings = K.reshape(y_pred, (-1, 3, output_dim))

    positive_distance = K.mean(K.square(embeddings[:,0] - embeddings[:,1]),axis=-1)
    negative_distance = K.mean(K.square(embeddings[:,0] - embeddings[:,2]),axis=-1)
    return K.mean(K.maximum(0.0, positive_distance - negative_distance + _alpha))

self._model.compile(loss=triplet_loss, optimizer="sgd")
self._model.fit(x=x,y=y,nb_epoch=1, batch_size=len(x))

where y is just a dummy array filled with 0s

The problem is that even after the first iteration with batch size 20 the model starts predicting the same embedding for all the images. So when I first do the prediction on the batch every embedding is different. Then I do the fit and predict again and suddenly all the embeddings becomes almost the same for all the images in the batch

Also notice that there is a Lambda layer at the end of the model. It normalizes the output of the net so all the embeddings has a unit length as it was suggested in the face net study.

Can anybody help me out here?

模型摘要

    Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_1 (InputLayer)             (None, 224, 224, 3)   0                                            
____________________________________________________________________________________________________
convolution2d_1 (Convolution2D)  (None, 112, 112, 64)  9472        input_1[0][0]                    
____________________________________________________________________________________________________
batchnormalization_1 (BatchNormal(None, 112, 112, 64)  128         convolution2d_1[0][0]            
____________________________________________________________________________________________________
maxpooling2d_1 (MaxPooling2D)    (None, 56, 56, 64)    0           batchnormalization_1[0][0]       
____________________________________________________________________________________________________
convolution2d_2 (Convolution2D)  (None, 56, 56, 64)    4160        maxpooling2d_1[0][0]             
____________________________________________________________________________________________________
batchnormalization_2 (BatchNormal(None, 56, 56, 64)    128         convolution2d_2[0][0]            
____________________________________________________________________________________________________
convolution2d_3 (Convolution2D)  (None, 56, 56, 192)   110784      batchnormalization_2[0][0]       
____________________________________________________________________________________________________
batchnormalization_3 (BatchNormal(None, 56, 56, 192)   384         convolution2d_3[0][0]            
____________________________________________________________________________________________________
maxpooling2d_2 (MaxPooling2D)    (None, 28, 28, 192)   0           batchnormalization_3[0][0]       
____________________________________________________________________________________________________
convolution2d_5 (Convolution2D)  (None, 28, 28, 96)    18528       maxpooling2d_2[0][0]             
____________________________________________________________________________________________________
convolution2d_7 (Convolution2D)  (None, 28, 28, 16)    3088        maxpooling2d_2[0][0]             
____________________________________________________________________________________________________
maxpooling2d_3 (MaxPooling2D)    (None, 28, 28, 192)   0           maxpooling2d_2[0][0]             
____________________________________________________________________________________________________
convolution2d_4 (Convolution2D)  (None, 28, 28, 64)    12352       maxpooling2d_2[0][0]             
____________________________________________________________________________________________________
convolution2d_6 (Convolution2D)  (None, 28, 28, 128)   110720      convolution2d_5[0][0]            
____________________________________________________________________________________________________
convolution2d_8 (Convolution2D)  (None, 28, 28, 32)    12832       convolution2d_7[0][0]            
____________________________________________________________________________________________________
convolution2d_9 (Convolution2D)  (None, 28, 28, 32)    6176        maxpooling2d_3[0][0]             
____________________________________________________________________________________________________
merge_1 (Merge)                  (None, 28, 28, 256)   0           convolution2d_4[0][0]            
                                                                   convolution2d_6[0][0]            
                                                                   convolution2d_8[0][0]            
                                                                   convolution2d_9[0][0]            
____________________________________________________________________________________________________
convolution2d_11 (Convolution2D) (None, 28, 28, 96)    24672       merge_1[0][0]                    
____________________________________________________________________________________________________
convolution2d_13 (Convolution2D) (None, 28, 28, 32)    8224        merge_1[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_4 (MaxPooling2D)    (None, 28, 28, 256)   0           merge_1[0][0]                    
____________________________________________________________________________________________________
convolution2d_10 (Convolution2D) (None, 28, 28, 64)    16448       merge_1[0][0]                    
____________________________________________________________________________________________________
convolution2d_12 (Convolution2D) (None, 28, 28, 128)   110720      convolution2d_11[0][0]           
____________________________________________________________________________________________________
convolution2d_14 (Convolution2D) (None, 28, 28, 64)    51264       convolution2d_13[0][0]           
____________________________________________________________________________________________________
convolution2d_15 (Convolution2D) (None, 28, 28, 64)    16448       maxpooling2d_4[0][0]             
____________________________________________________________________________________________________
merge_2 (Merge)                  (None, 28, 28, 320)   0           convolution2d_10[0][0]           
                                                                   convolution2d_12[0][0]           
                                                                   convolution2d_14[0][0]           
                                                                   convolution2d_15[0][0]           
____________________________________________________________________________________________________
convolution2d_16 (Convolution2D) (None, 28, 28, 128)   41088       merge_2[0][0]                    
____________________________________________________________________________________________________
convolution2d_18 (Convolution2D) (None, 28, 28, 32)    10272       merge_2[0][0]                    
____________________________________________________________________________________________________
convolution2d_17 (Convolution2D) (None, 14, 14, 256)   295168      convolution2d_16[0][0]           
____________________________________________________________________________________________________
convolution2d_19 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_18[0][0]           
____________________________________________________________________________________________________
maxpooling2d_5 (MaxPooling2D)    (None, 14, 14, 320)   0           merge_2[0][0]                    
____________________________________________________________________________________________________
merge_3 (Merge)                  (None, 14, 14, 640)   0           convolution2d_17[0][0]           
                                                                   convolution2d_19[0][0]           
                                                                   maxpooling2d_5[0][0]             
____________________________________________________________________________________________________
convolution2d_21 (Convolution2D) (None, 14, 14, 96)    61536       merge_3[0][0]                    
____________________________________________________________________________________________________
convolution2d_23 (Convolution2D) (None, 14, 14, 32)    20512       merge_3[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_6 (MaxPooling2D)    (None, 14, 14, 640)   0           merge_3[0][0]                    
____________________________________________________________________________________________________
convolution2d_20 (Convolution2D) (None, 14, 14, 256)   164096      merge_3[0][0]                    
____________________________________________________________________________________________________
convolution2d_22 (Convolution2D) (None, 14, 14, 192)   166080      convolution2d_21[0][0]           
____________________________________________________________________________________________________
convolution2d_24 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_23[0][0]           
____________________________________________________________________________________________________
convolution2d_25 (Convolution2D) (None, 14, 14, 128)   82048       maxpooling2d_6[0][0]             
____________________________________________________________________________________________________
merge_4 (Merge)                  (None, 14, 14, 640)   0           convolution2d_20[0][0]           
                                                                   convolution2d_22[0][0]           
                                                                   convolution2d_24[0][0]           
                                                                   convolution2d_25[0][0]           
____________________________________________________________________________________________________
convolution2d_27 (Convolution2D) (None, 14, 14, 112)   71792       merge_4[0][0]                    
____________________________________________________________________________________________________
convolution2d_29 (Convolution2D) (None, 14, 14, 32)    20512       merge_4[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_7 (MaxPooling2D)    (None, 14, 14, 640)   0           merge_4[0][0]                    
____________________________________________________________________________________________________
convolution2d_26 (Convolution2D) (None, 14, 14, 224)   143584      merge_4[0][0]                    
____________________________________________________________________________________________________
convolution2d_28 (Convolution2D) (None, 14, 14, 224)   226016      convolution2d_27[0][0]           
____________________________________________________________________________________________________
convolution2d_30 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_29[0][0]           
____________________________________________________________________________________________________
convolution2d_31 (Convolution2D) (None, 14, 14, 128)   82048       maxpooling2d_7[0][0]             
____________________________________________________________________________________________________
merge_5 (Merge)                  (None, 14, 14, 640)   0           convolution2d_26[0][0]           
                                                                   convolution2d_28[0][0]           
                                                                   convolution2d_30[0][0]           
                                                                   convolution2d_31[0][0]           
____________________________________________________________________________________________________
convolution2d_33 (Convolution2D) (None, 14, 14, 128)   82048       merge_5[0][0]                    
____________________________________________________________________________________________________
convolution2d_35 (Convolution2D) (None, 14, 14, 32)    20512       merge_5[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_8 (MaxPooling2D)    (None, 14, 14, 640)   0           merge_5[0][0]                    
____________________________________________________________________________________________________
convolution2d_32 (Convolution2D) (None, 14, 14, 192)   123072      merge_5[0][0]                    
____________________________________________________________________________________________________
convolution2d_34 (Convolution2D) (None, 14, 14, 256)   295168      convolution2d_33[0][0]           
____________________________________________________________________________________________________
convolution2d_36 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_35[0][0]           
____________________________________________________________________________________________________
convolution2d_37 (Convolution2D) (None, 14, 14, 128)   82048       maxpooling2d_8[0][0]             
____________________________________________________________________________________________________
merge_6 (Merge)                  (None, 14, 14, 640)   0           convolution2d_32[0][0]           
                                                                   convolution2d_34[0][0]           
                                                                   convolution2d_36[0][0]           
                                                                   convolution2d_37[0][0]           
____________________________________________________________________________________________________
convolution2d_39 (Convolution2D) (None, 14, 14, 144)   92304       merge_6[0][0]                    
____________________________________________________________________________________________________
convolution2d_41 (Convolution2D) (None, 14, 14, 32)    20512       merge_6[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_9 (MaxPooling2D)    (None, 14, 14, 640)   0           merge_6[0][0]                    
____________________________________________________________________________________________________
convolution2d_38 (Convolution2D) (None, 14, 14, 160)   102560      merge_6[0][0]                    
____________________________________________________________________________________________________
convolution2d_40 (Convolution2D) (None, 14, 14, 288)   373536      convolution2d_39[0][0]           
____________________________________________________________________________________________________
convolution2d_42 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_41[0][0]           
____________________________________________________________________________________________________
convolution2d_43 (Convolution2D) (None, 14, 14, 128)   82048       maxpooling2d_9[0][0]             
____________________________________________________________________________________________________
merge_7 (Merge)                  (None, 14, 14, 640)   0           convolution2d_38[0][0]           
                                                                   convolution2d_40[0][0]           
                                                                   convolution2d_42[0][0]           
                                                                   convolution2d_43[0][0]           
____________________________________________________________________________________________________
convolution2d_44 (Convolution2D) (None, 14, 14, 160)   102560      merge_7[0][0]                    
____________________________________________________________________________________________________
convolution2d_46 (Convolution2D) (None, 14, 14, 64)    41024       merge_7[0][0]                    
____________________________________________________________________________________________________
convolution2d_45 (Convolution2D) (None, 7, 7, 256)     368896      convolution2d_44[0][0]           
____________________________________________________________________________________________________
convolution2d_47 (Convolution2D) (None, 7, 7, 128)     204928      convolution2d_46[0][0]           
____________________________________________________________________________________________________
maxpooling2d_10 (MaxPooling2D)   (None, 7, 7, 640)     0           merge_7[0][0]                    
____________________________________________________________________________________________________
merge_8 (Merge)                  (None, 7, 7, 1024)    0           convolution2d_45[0][0]           
                                                                   convolution2d_47[0][0]           
                                                                   maxpooling2d_10[0][0]            
____________________________________________________________________________________________________
convolution2d_49 (Convolution2D) (None, 7, 7, 192)     196800      merge_8[0][0]                    
____________________________________________________________________________________________________
convolution2d_51 (Convolution2D) (None, 7, 7, 48)      49200       merge_8[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_11 (MaxPooling2D)   (None, 7, 7, 1024)    0           merge_8[0][0]                    
____________________________________________________________________________________________________
convolution2d_48 (Convolution2D) (None, 7, 7, 384)     393600      merge_8[0][0]                    
____________________________________________________________________________________________________
convolution2d_50 (Convolution2D) (None, 7, 7, 384)     663936      convolution2d_49[0][0]           
____________________________________________________________________________________________________
convolution2d_52 (Convolution2D) (None, 7, 7, 128)     153728      convolution2d_51[0][0]           
____________________________________________________________________________________________________
convolution2d_53 (Convolution2D) (None, 7, 7, 128)     131200      maxpooling2d_11[0][0]            
____________________________________________________________________________________________________
merge_9 (Merge)                  (None, 7, 7, 1024)    0           convolution2d_48[0][0]           
                                                                   convolution2d_50[0][0]           
                                                                   convolution2d_52[0][0]           
                                                                   convolution2d_53[0][0]           
____________________________________________________________________________________________________
convolution2d_55 (Convolution2D) (None, 7, 7, 192)     196800      merge_9[0][0]                    
____________________________________________________________________________________________________
convolution2d_57 (Convolution2D) (None, 7, 7, 48)      49200       merge_9[0][0]                    
____________________________________________________________________________________________________
maxpooling2d_12 (MaxPooling2D)   (None, 7, 7, 1024)    0           merge_9[0][0]                    
____________________________________________________________________________________________________
convolution2d_54 (Convolution2D) (None, 7, 7, 384)     393600      merge_9[0][0]                    
____________________________________________________________________________________________________
convolution2d_56 (Convolution2D) (None, 7, 7, 384)     663936      convolution2d_55[0][0]           
____________________________________________________________________________________________________
convolution2d_58 (Convolution2D) (None, 7, 7, 128)     153728      convolution2d_57[0][0]           
____________________________________________________________________________________________________
convolution2d_59 (Convolution2D) (None, 7, 7, 128)     131200      maxpooling2d_12[0][0]            
____________________________________________________________________________________________________
merge_10 (Merge)                 (None, 7, 7, 1024)    0           convolution2d_54[0][0]           
                                                                   convolution2d_56[0][0]           
                                                                   convolution2d_58[0][0]           
                                                                   convolution2d_59[0][0]           
____________________________________________________________________________________________________
averagepooling2d_1 (AveragePoolin(None, 1, 1, 1024)    0           merge_10[0][0]                   
____________________________________________________________________________________________________
flatten_1 (Flatten)              (None, 1024)          0           averagepooling2d_1[0][0]         
____________________________________________________________________________________________________
dense_1 (Dense)                  (None, 128)           131200      flatten_1[0][0]                  
____________________________________________________________________________________________________
lambda_1 (Lambda)                (None, 128)           0           dense_1[0][0]                    
====================================================================================================
Total params: 7456944
____________________________________________________________________________________________________
None

小智.. 8

除了学习率太高之外,可能发生的事情是,有效地使用了不稳定的三元组选择策略.例如,如果你只使用'硬三元组'(距离小于ap距离的三元组),你的网络权重可能会将所有嵌入都折叠到一个点(使得损失总是等于保证金(你的_alpha)),因为所有嵌入距离均为零).

这也可以通过使用其他类型的三元组来修复(如'半硬三元组',其中ap小于a,但ap和a之间的距离仍然小于边距).所以,如果你总是检查这个...可以在这篇博文中详细解释:https://omoindrot.github.io/triplet-loss



1> 小智..:

除了学习率太高之外,可能发生的事情是,有效地使用了不稳定的三元组选择策略.例如,如果你只使用'硬三元组'(距离小于ap距离的三元组),你的网络权重可能会将所有嵌入都折叠到一个点(使得损失总是等于保证金(你的_alpha)),因为所有嵌入距离均为零).

这也可以通过使用其他类型的三元组来修复(如'半硬三元组',其中ap小于a,但ap和a之间的距离仍然小于边距).所以,如果你总是检查这个...可以在这篇博文中详细解释:https://omoindrot.github.io/triplet-loss



2> Chris Anders..:

你是否限制你的嵌入"处于d维超球面"?尝试tf.nn.l2_normalize在CNN出来后立即运行嵌入.

问题可能在于嵌入有点像智能无线电.减少损失的一种简单方法是将所有内容设置为零.l2_normalize迫使它们成为单位长度.

它看起来你想要在最后一个平均池之后添加正常化.

推荐阅读
拾味湖
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有