基于對比學習的圖文預訓練方式,自從CLIP [1] 橫空出世后,就成為了圖文預訓練的主流方式,引申出了一系列的工作,如ALIGN [3]、FLIP [4]、LiT [5]等。這些工作在數據使用、訓練效率等上進行了探索,但是其核心損失還是采用了infoNCE,也即是對比型的損失。在SigLIP [2] 中,作者提出了基于sigmoid損失的圖文預訓練方式,并且指出采用sigmoid損失能帶來更高效的圖文預訓練效率和效果。在此之前,我們有必要再復習下CLIP的基本思想。CLIP是一個雙塔結構,分別有圖片塔f(⋅)和文本塔g(⋅),那么損失可以表達為式子(1),其中的xi=f(Ii)||f(Ii)||2和yi=g(Ti)||g(Ti)||2,是圖片特征和文本特征的L2 normalization后的結果,t=exp?(t′)為溫度系數,其中t′為可學習參數,B為一個批次(batch)的數據。
其基本思想就是從打分矩陣中,從i->t
和t->i
的方向去判斷出正樣本的位置(也就是對角線的位置),注意到由于采用的是softmax形式去歸一化正負樣本,將其視為了概率分布,因此正負樣本之間的概率關系是耦合在一起的,在提高正樣本概率的同時,勢必會壓低負樣本的概率。
Fig 1. CLIP的基本結構由圖片塔和文本塔組成,打分矩陣的對角線為正樣本,從i2t和t2i的方向分別計算infoNCE損失。
而在SigLIP中,損失函數為式子(2)所示,其中的zij為給定圖片文本對的標簽,當為成對的正樣本時候zij=1,當不是成對的負樣本時候zij=−1。此時對于正負樣本來說是解耦的,增加正樣本的概率并不意味著壓低負樣本的概率。負樣本數量的絕對占優,會導致在訓練初期負樣本的損失主導了整個損失,因此引入了一個可學習的偏置項b去緩解初始階段的訓練困難問題,此處的b在原文中被初始化為-10,這也容易理解,初始的logit減去一個較大的值(如-10),使得正負樣本logit的差別相對不會很大,在正負樣本數量極大不均勻的情況下,可以讓初始狀態更加均勻,從而不會帶來過度調整。
整個損失的建模,如以下代碼所示:
# img_emb : image model embedding [n, dim]
# txt_emb : text model embedding [n, dim]
# t_prime, b : learnable temperature and bias
# n : mini-batch size
t = exp(t_prime)
zimg = l2_normalize(img_emb)
ztxt = l2_normalize(txt_emb)
logits = dot(zimg, ztxt.T) * t + b
labels = 2 * eye(n) - ones(n) # -1 with diagonal 1
l = -sum(log_sigmoid(labels * logits)) / n
這個就是SigLIP的核心優化點,我們先不考慮這個建模的模型效果,先看到這種建模方式帶來的模型訓練的優勢。
- 在CLIP中,正負樣本是pairwise建模的:由于采用的softmax函數去建模正負樣本之間的關系,而CLIP訓練的global batch size一般都很大(如32k),這意味著GPU #1上的正樣本需要見到其他所有GPU上的樣本,并以此作為負樣本。因此通常都需要匯聚多節點多卡的特征向量,這個時候需要頻繁地調用
all_gather
,帶來了沉重的通信代價,也會拖慢整個訓練過程的速度。 - 在SigLIP中,正負樣本是pointwise建模的:采用的sigmoid loss是獨立對每個正負樣本進行計算的,最后再進行loss的累加,這意味著可以在本地完成大部分的計算,在涉及到本地的正樣本和其他設備的負樣本進行交互計算的時候,僅需要很少的
gather
操作就能完成設備間向量的交換就可以(對于圖文預訓練來說,交換文本特征向量即可,通信代價很低),而不需要all_gather
操作。
我們著重介紹下SigLIP是如何進行分布式訓練的,假設全局的batch size為B,一共有D個GPU,那么每個GPU上的batch size為b=B/D,可以將公式(2)的損失拆解為公式(3)所示,在Fig 2. 展示了整個過程的示意圖,在初始化階段,我們以第一個GPU為例子,其所包含的樣本為:
此時GPU 1可以完成一次公式(3)中的C計算,然后,交換GPU 1和GPU 2的文本編碼器特征向量,既是:
此時GPU 2完成一次公式(3)中的B計算,以此類推,直到GPU 1遍歷完所有樣本為止,其他GPU也是如此操作的,最終把所有卡上的損失匯聚即可,也就是A計算。這個輪流交換不同GPU之間數據的操作,可以稱之為permutation。不難發現,整個過程的通信成本來自于permutation,一共需要D−1次gather
操作即可完成一次permutation,而在CLIP中需要對圖文的編碼器特征都進行匯聚,因此需要2次all-gather
操作。如果all-gather
采用ring的話,那么一個all-gather
就是D−1次gather
操作。由此我們得出一個SigLIP和CLIP的性能復雜度對比:
容易發現,SigLIP無論從通信復雜度,儲存復雜度還是計算復雜度上,都遠比CLIP更為優越。
Fig 2. SigLIP高效的損失計算示意圖,假設有3個設備,每個設備上的batch size為4,global batch size為12。
讓我們再關注到SigLIP的模型能力表現,作者主要對比的是SigLIP,以及將圖片表征固定的SigLiT(從而可以將batch size設置到非常大,比如100萬)以及CLIP的表現。我們都知道在CLIP中采用對比損失,意味著越大的batch size能極大地提高對比效率,從而提升效果,受限于softmax的內存占用情況和GPU卡數等原因,無法將batch size設置得很大,在SigLiT中則可以將batch size設置到百萬以上,從而探索極大batch size情況下的收益。如Fig 3.所示,作者對比了三種模型在batch size進行尺度放大后的0-shot能力,訓練量都是18B的數據量,容易發現幾點結論:
- 在batch size小于32k的時候,采用sigmoid的SigLIP的性能都會優于采用softmax的CLIP。
- 在batch size足夠大的時候,CLIP能夠追上,甚至超越SigLIP的表現,但是最佳性能仍然是SigLIP@32k情況下得到,從實際應用來看,采用SigLIP能用更少的資源更快的訓練速度得到更好的性能。
- 從SigLiT的實驗來看,隨著batch size的尺度放大性能將會在32k batch size的情況下達到飽和,同時能觀察到SigLiT在不同batch size下持續優于LiT。繼續提高batch size將不能帶來更好的收益,甚至會有細微的性能下降。
Fig 3. SigLiT、SigLIP和CLIP在batch size進行尺度放大情況下的0-shot性能對比。
超大的batch size是否需要更長的訓練量支持?作者在SigLiT的設置下,訓練了更長時間(也即是見了更多數據量),如Fig 4.所示,在超大batch size,如262k的情況下,對比較小batch size(如8k)提供更長的訓練時間的確能觀察到性能的較大提升。并且也能觀察到在超大batch size下,采用sigmoid和采用softmax的區別很小,但是在較小batch size(如8k)的情況下,則差距明顯。因此,在資源受限的情況下,采用SigLIP是很劃算的,所需資源少而且性能更強。同時,這個試驗也說明了,超大的batch size并不意味著訓練得更快,反而還需要更長的訓練時間。
Fig 4. 擴大了見過的數據量后,越大的batch size能帶來較為明顯的性能提升,同時,可以觀察到在超大batch size下,采用sigmoid和采用softmax的區別很小,但是在較小batch size(如8k)的情況下,則差距明顯。
除了batch size的影響外,作者還探索了很多有趣的點,包括SigLIP在多語言數據集上的表現、大尺度batch size下的訓練穩定性問題、訓練中負樣本比例的影響、sigmoid訓練的魯棒性探索等問題。在多語言數據集上,作者發現性能同樣在32k batch size上達到了飽和,其他細節就不累述了,感興趣的讀者自行翻閱。
筆者比較感興趣的是其他問題,比如在大尺度batch size下,訓練容易出現不穩定的情況,這個原因在于在訓練過程中,gradient norm會出現大幅度的尖峰,如Fig 5. 所示,這導致了參數和訓練損失的不穩定(也即是尖峰),作者觀察到,如果將Adam優化器的β2值從0.999下降到0.95,那么訓練過程就會穩定下來。
Fig 5. 大尺度batch size下訓練容易出現不穩定的情況。
從公式(2)中,注意到SigLIP是對所有正負樣本的pair進行計算損失然后累加求和的,這意味著可以從中剔除掉負樣本以控制負樣本的比例。對于batch size為|B|的一次損失計算而言,其中有|B|個正樣本,有|B|2−|B|個負樣本,負樣本其實在后期很多都是簡單負樣本,是否可以剔除簡單負樣本是一個值得探究的問題。作者提出了幾種消融試驗,去挑選負樣本,從而控制正負樣本比例:
- 隨機:隨機挑選負樣本對,并且對其進行剔除。
- 難負樣本:把難負樣本保留下來,即是通過將最高打分的topk負樣本保留下來。
- 簡單負樣本:把簡單負樣本保留下來,即是將打分最低的lowk負樣本保留下來。
- 難負樣本+對齊訓練量:保留難負樣本的同時,提高訓練step數量以對齊訓練數據量。
實驗結果如Fig 6.所示,其中的橫坐標為一個batch內的正樣本數量:負樣本數量
,其中的1:16k
則是不進行任何負樣本剔除的基線,從實驗中可以得出幾個結論:
- 只保留簡單負樣本,會使得模型性能完全崩潰。
- 隨機剔除負樣本,也會損失模型性能。
- 只保留難負樣本,對模型性能的損失是最小的,在對齊了訓練數據量后(因為剔除了負樣本,同個step下模型講過的數據對數量就少了,因此需要訓練更多step去對齊訓練數據量),性能甚至還能比基線更好,這說明了難負樣本才是最有價值的,怎么去合理地挑選難負樣本是提高模型性能的關鍵因素。
- 再看到隨著負樣本數量的減少,可學習偏置b的值和正負樣本的平均logit值都在遞增,這也是符合預期的。有趣的一點是,當采用難負例保留的策略中,隨著負樣本數量逐漸減少,正負例的logit區分度在加速減少,并且正例的logit變化基本上是平坦的,這個現象和隨機丟棄的策略是不同的。對此的解釋是,本來難負樣本和正樣本就比較接近,在減少了負樣本數量,只保留最難的負樣本后,負樣本的logit值就加速地上漲,從而導致了區分度減低的情況,這也是符合預期的。
Fig 6. 采用不同策略控制損失中的正負樣本比例的效果對比。
前文已經提到了sigmoid和softmax的區別在于,前者解耦了正負樣本的概率關系,這使得即便負樣本中出現假陰性樣本,也只會影響自己的損失,而不會影響到其他樣本,因此這帶來了數據的健壯性。作者也進行了對應的試驗,如Fig 7.所示,作者對數據中的圖片、文本進行隨機加噪、對batch內的圖文對進行隨機打亂、或者將上面的加噪方式都進行組合,發現基于sigmoid的訓練過程,總是比基于softmax的訓練過程更加魯棒。
Fig 7. 基于sigmoid的訓練能夠提高訓練的健壯性,對數據中的噪聲更為魯棒。
總的來說,SigLIP是一個很棒的工作,作者采用sigmoid損失去取代對比學習中的softmax函數,以更小的資源開銷帶來了更好的模型表現,目前也被很多多模態大模型所采用,作為視覺端的編碼器。
Reference
[1]. Radford, A., Kim, J. W., Hallacy, C., Ramesh, A., Goh, G., Agarwal, S., ... & Sutskever, I. (2021, July). Learning transferable visual models from natural language supervision. In International Conference on Machine Learning (pp. 8748-8763). PMLR. aka CLIP
[2]. Zhai, Xiaohua, Basil Mustafa, Alexander Kolesnikov, and Lucas Beyer. "Sigmoid loss for language image pre-training." In Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 11975-11986. 2023. aka SigLIP
[3]. Jia, C., Yang, Y., Xia, Y., Chen, Y. T., Parekh, Z., Pham, H., ... & Duerig, T. (2021, July). Scaling up visual and vision-language representation learning with noisy text supervision. In International Conference on Machine Learning (pp. 4904-4916). PMLR. Short for ALIGN
[4]. Li, Y., Fan, H., Hu, R., Feichtenhofer, C., & He, K. (2022). Scaling Language-Image Pre-training via Masking. arXiv preprint arXiv:2212.00794. aka FLIP
[5]. Zhai, X., Wang, X., Mustafa, B., Steiner, A., Keysers, D., Kolesnikov, A., & Beyer, L. (2022). Lit: Zero-shot transfer with locked-image text tuning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 18123-18133). aka LiT