在筆者對于對比學習的認識中,主要有2個維度的事情需要考慮:
如何選取合適的負樣本如何選取合適的損失函數以下結合一些訓練經驗,簡要筆記下。
如何構造合適的負樣本之前在[1]中簡單介紹過一些構造負樣本的方法,總體來說,基于用戶行為數據我們可以通過batch negative和無點數據進行負樣本構建。
batch negative在搜索過程中,用戶行為存在很大的隨機性,比如有展現但沒有點擊的數據并不一定就是負樣本,為了獲取更可靠的用戶數據,我們可以選擇在用戶點擊過的Doc之間組成負樣本。沒錯,我們認為用戶點擊過的行為是更為可靠的,雖然即便是點擊行為也可能只是因為用戶的好奇行為或者誤操作等等,但是對比于無點行為總歸是更為可靠的。假設用戶的query i ii和點擊過的Doc組成二元組,其中的C \mathcal{C}C表示所有有點行為的集合,那么我們認為其負樣本就是
。當數據足夠龐大是,有點數據
的規模也會非常龐大,我們無法一次將所有負樣本都列舉出來(同時,也沒有必要),我們通常會在一個batch內對所有用戶點擊二元組進行組合。也即是將
的規模限制在一個batch內,如Fig 1.1所示,其中的對角線都是二元組正樣本,而其他元素都是負樣本。通過一個矩陣乘法,我們就可以實現這個操作。如式子(1.1)所示。
Fig 1.1 Batch Negative的方式從一個batch中構造負樣本。
無點擊樣本無點數據也不是一無是處,在某些搜索產品中,如果排序到前面的結果本身就不夠好,那么用戶的點擊數據和無點擊數據就具有足夠的區分度,無點數據拿來視為負樣本就是合理的,這個和產品具體的設計,或者呈現UI形式等等有關,需要在實踐中才能實驗出來。
在實踐中,通常還會去進行batch negative和無點擊數據的混合以達到獲取足夠多的負樣本的目的。
使用何種損失函數常用在對比學習中的損失函數主要有兩種,hinge loss[2]和交叉熵損失。其中的hinge loss形式如(2.1)所示:
hinge loss和SVM一樣[3],存在一個margin,一旦正樣本和負樣本打分的差距超過這個margin,那么損失就變為0,通過這種手段可以讓hinge loss學習到正樣本和負樣本之間的表征區別,而且又可以更好地控制訓練過程。而交叉熵損失是我們的老朋友了,如式子(2.2)所示
其中的N為樣本數量,M為分類類別數量,而則是預測的logit經過softmax之后的概率分布。注意到,正如Fig 1.1所示,對于每個
而言,其每一行都有
個負樣本;對于每個
而言,其每一列都有
個負樣本,那么就可以組織雙向的損失函數計算。這種方式對于雙塔模型結構來說特別地“劃算”,因為對于雙塔模型而言只需要計算一次矩陣計算就可以得到
的打分矩陣,然后通過雙向計算損失,可以實現更高效地對模型的訓練。
在hinge loss計算過程中,還可以通過在這每一行(或者每一列)的N − 1個負樣本中選擇一個最難的負樣本,也就是打分最高的負樣本。這一點很容易理解,負樣本的打分如果打得很高,那么就可以認為模型很大程度地將其誤認為正樣本了,如果能將其分開,那么模型的表征能力應該是更上一層樓的,因此將最難負樣本作為式子(2.1)的進行訓練。
訓練過程在對比學習訓練過程中,我們暫時只考慮雙塔模型(因為交互式模型的負樣本選取策略不同),雖然理論上hinge loss這種基于pairwise樣本選取策略的損失,可以很好地對比正負樣本的表征區別,但是如果模型并沒有進行很好地訓練就拿去用hinge loss進行訓練,有可能因為負樣本太難導致訓練出現“損失坍縮”(loss collapse)的現象,此時模型對正樣本和負樣本沒有區分能力,因此對兩者的打分都極為相似,有,此時loss坍縮到margin并且恒等于margin不再變化,如Fig 3.1所示。我們可以認為模型陷入了平凡解。
Fig 3.1 采用hinge loss導致損失坍縮的現象。圖省事就直接ipad上畫了,有點丑見諒:_)這個現象也不一定就會出現,如果采用的模型已經進行過合適的初始化,就不一定會出現這個問題。另外,采用交叉熵損失進行一開始的訓練是一種比較穩定的方法。在CLIP模型中[4],作者采用了batch size=32,768的配置,在進行過allgather機制,對所有特征進行匯聚后[5],甚至可以實現32768 × 32768 大小的打分矩陣,這意味著有著海量的負樣本可供學習,這也同時意味著對模型學習的巨大挑戰。因此CLIP文章的作者沒有采用hinge loss訓練,而是采用了雙向的交叉熵損失進行訓練。
然而在巨大的batch size中訓練是有著非常大的誘惑的,在[1]中我們就曾經討論過對于對比學習中,負樣本增多意味著表征詞典的詞表的增大,有著巨大的效果增益。那么要如何去訓練這種超大規模的batch size下的對比學習任務呢?筆者個人認為需要進行階段式地訓練,一步步提高batch size大小。筆者曾經試驗過,如果一開始就采用很大的batch size進行訓練,在hinge loss的情況下,將會非常不穩定,很容易出現損失坍塌的現象。而如果循序漸進則不會出現這個問題,那么是否可以通過這種方法將batch size增加到很大呢(不考慮硬件的約束情況),這個筆者也還在實驗,希望后續能有個比較正向的結論。
同時,在超大規模的對比學習過程中,如何結合交叉熵損失和hinge loss損失也是一個值得思考的問題。交叉熵損失穩定,但是學習速度較慢(筆者實驗發現,不一定準確),hinge loss不穩定,但是學習速度更快,如何進行平衡是一個值得嘗試的方向。對比學習在大規模數據上的訓練的確還有很多值得探索的呢,公開的論文提供的細節也不多。
Reference
[1]. https://fesian.blog.csdn.net/article/details/119515146
[2]. https://blog.csdn.net/LoseInVain/article/details/103995962
[3]. https://blog.csdn.net/LoseInVain/article/details/78636176
[4]. https://fesian.blog.csdn.net/article/details/119516894
[5]. https://medium.com/@cresclux/example-on-torch-distributed-gather-7b5921092