App下載

pytorch 多個(gè)反向傳播操作

宇宙一級(jí)潛在鴿王 2021-08-19 11:07:23 瀏覽數(shù) (2379)
反饋

之前小編的一篇文章pytorch 計(jì)算圖以及backward,講了一些pytorch中基本的反向傳播,理清了梯度是如何計(jì)算以及下降的,建議先看懂那個(gè),然后再看這個(gè)。

從一個(gè)錯(cuò)誤說(shuō)起:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed

在深度學(xué)習(xí)中,有些場(chǎng)景需要進(jìn)行兩次反向,比如Gan網(wǎng)絡(luò),需要對(duì)D進(jìn)行一次,還要對(duì)G進(jìn)行一次,很多人都會(huì)遇到上面這個(gè)錯(cuò)誤,這個(gè)錯(cuò)誤的意思就是嘗試對(duì)一個(gè)計(jì)算圖進(jìn)行第二次反向,但是計(jì)算圖已經(jīng)釋放了。

其實(shí)看簡(jiǎn)單點(diǎn)和我們之前的backward一樣,當(dāng)圖進(jìn)行了一次梯度更新,就會(huì)把一些梯度的緩存給清空,為了避免下次疊加,但在Gan這種情形下,我們必須要二次更新,那怎么辦呢。

有兩種方案:

方案一:

這是網(wǎng)上大多數(shù)給出的解決方案,在第一次反向時(shí)候加入一個(gè)l2.backward(),這樣就能避免釋放掉了。

方案二:

上面的方案雖然解決了問(wèn)題,但是并不優(yōu)美,因?yàn)槲覀冇肎an的時(shí)候,D和G兩者的更新并無(wú)聯(lián)系,二者的聯(lián)系僅僅是D里面用到了G的輸出,而這個(gè)輸出一般我們都是直接拿來(lái)用的,而問(wèn)題就出現(xiàn)在這里。

下面給一個(gè)模擬:

data = torch.randn(4,10)

model1 = torch.nn.Linear(10,2)
model2 = torch.nn.Linear(2,2)

optimizer1 = torch.optim.Adam(model1.parameters(), lr=0.001,betas=(0.5, 0.999))
optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.001,betas=(0.5, 0.999))

loss = torch.nn.CrossEntropyLoss()
data = torch.randn(4,10)
label = torch.Tensor([0,1,1,0]).long()
for i in range(20):
    a = model1(data)
    b = model2(a)
    l1 = loss(a,label)
    l2 = loss(b,label)
    optimizer2.zero_grad()
    l2.backward()
    optimizer2.step()

    optimizer1.zero_grad()
    l1.backward()
    optimizer1.step()

上面定義了兩個(gè)模型,而model2的輸入是model1的輸出,而更新的時(shí)候,二者都是各自更新自己的參數(shù),并無(wú)聯(lián)系,但是上面的代碼會(huì)報(bào)一個(gè)RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed 這樣的錯(cuò),解決方案可以是l2.backward(retain_graph=True)。

除此之外我們還可以是b = model2(a.detach()),這個(gè)就優(yōu)美一點(diǎn),a.detach()和a的區(qū)別你可以打印出來(lái)看一下,其實(shí)a.detach()是沒(méi)有梯度的,所以相當(dāng)于一個(gè)單純的數(shù)字,和model1就脫離了聯(lián)系,這樣model2和model1就是完全分離開(kāi)來(lái)的兩個(gè)圖,但是如果用的是a則model2和model1則仍然公用一個(gè)圖,所以導(dǎo)致了錯(cuò)誤。

可以看下面示意圖(這個(gè)是我猜測(cè),幫助理解):

2019-11-26_101938.jpg

左邊相當(dāng)于直接用a而右邊則用a.detach(),類似的在Gan網(wǎng)絡(luò)里面D的輸入可以改為G的輸出y_fake.detach()。

但有一點(diǎn)需要注意的是,兩個(gè)網(wǎng)絡(luò)一定沒(méi)有需要共同更新的 ,假如上面的optimizer2 = torch.optim.Adam(itertools.chain(model1.parameters(),model2.parameters()), lr=0.001,betas=(0.5, 0.999)),則還是用retain_graph=True保險(xiǎn),因?yàn)?detach則model2反向不會(huì)傳播到model1,導(dǎo)致不對(duì)model1里面參數(shù)更新。

補(bǔ)充:聊聊Focal Loss及其反向傳播

我們都知道,當(dāng)前的目標(biāo)檢測(cè)(Objece Detection)算法主要分為兩大類:two-stage detector和one-stage detector。two-stage detector主要包括rcnn、fast-rcnn、faster-rcnn和rfcn等,one-stage detector主要包括yolo和ssd等,前者精度高但檢測(cè)速度較慢,后者精度低些但速度很快。

對(duì)于two-stage detector而言,通常先由RPN生成proposals,再由RCNN對(duì)proposals進(jìn)行Classifcation和Bounding Box Regression。這樣做的一個(gè)好處是有利于樣本和模型之間的feature alignment,從而使Classification和Bounding Box Regression更容易些;此外,RPN和RCNN中存在正負(fù)樣本不均衡的問(wèn)題,RPN直接限制正負(fù)樣本的比例為1:1,對(duì)于固定的rpn_batch_size,正樣本不足的情況下才用負(fù)樣本來(lái)填充,RCNN則是直接限制了正負(fù)樣本的比例為1:3或者采用OHEM。

對(duì)于one-stage detector而言,樣本和模型之間的feature alignment只能通過(guò)reception field來(lái)實(shí)現(xiàn),且直接通過(guò)回歸方式進(jìn)行預(yù)測(cè),存在這嚴(yán)重的正負(fù)樣本數(shù)據(jù)不均衡(1:1000)的問(wèn)題,負(fù)樣本的比例過(guò)高,占據(jù)了loss的絕大部分,且大多數(shù)是容易分類的,這使得模型的訓(xùn)練朝著不希望的方向前進(jìn)。作者認(rèn)為這種數(shù)據(jù)的嚴(yán)重不均衡是造成one-stage detector精度低的主要原因,因此提出Focal Loss來(lái)解決這一問(wèn)題

通過(guò)人工控制正負(fù)樣本比例或者OHEM能夠一定程度解決數(shù)據(jù)不均衡問(wèn)題,但這兩種方法都比較粗暴,采用這種“一刀切”的方式有可能把一些hard examples忽略掉。因此,作者提出了一種新的損失函數(shù)Focal Loss,不忽略任何樣本,同時(shí)又能讓模型訓(xùn)練時(shí)更加專注在hard examples上。簡(jiǎn)單說(shuō)明下Focal loss的原理

Focal Loss是在標(biāo)準(zhǔn)的交叉熵?fù)p失的基礎(chǔ)上改進(jìn)而來(lái)。以二分類為例,標(biāo)準(zhǔn)的交叉熵?fù)p失函數(shù)為

針對(duì)類別不均衡,針對(duì)對(duì)不同類別對(duì)loss的貢獻(xiàn)進(jìn)行控制即可,也就是加一個(gè)控制權(quán)重αt,那么改進(jìn)后的balanced cross entropy loss為

但是balanced cross entropy loss沒(méi)辦法讓訓(xùn)練時(shí)專注在hard examples上。實(shí)際上,樣本的正確分類概率pt越大,那么往往說(shuō)明這個(gè)樣本越易分。所以,最終的Focal Loss為

Focal Loss存在這兩個(gè)超參數(shù)(hyperparameter),不同的αt和γ,對(duì)于的loss如Figure 1所示。從Figure 4, 我們可以看到γ的變化對(duì)正(forground)樣本的累積誤差的影響并不大,但是對(duì)于負(fù)(background)樣本的累積誤差的影響還是很大的(γ=2時(shí),將近99%的background樣本的損失都非常?。?/p>

接下來(lái)看下實(shí)驗(yàn)結(jié)果,為了驗(yàn)證Focal Loss,作者提出了一種新的one-stage detector架構(gòu)RetinaNet,采用的是resnet_fpn,同時(shí)scales增加到15個(gè),如Figure 3所示

Table 1給出了RetinaNet和Focal Loss的一些實(shí)驗(yàn)結(jié)果,從中我們看出增加α-類別均衡,AP提高了0.9,再增加了γ控制,AP達(dá)到了37.8.Focal Local相比于OHEM,AP提高了3.2。從Table 2可以看出,增加訓(xùn)練時(shí)間并采用scale jitter,AP最終那達(dá)到39.1。

Focal Loss的原理分析和實(shí)驗(yàn)結(jié)果至此結(jié)束了,那么,我們接下來(lái)看下Focal Loss的反向傳播。首先給出Softmax Activation的反向梯度傳播公式,為

有了Softmax Activation的反向梯度傳播公式,根據(jù)鏈?zhǔn)椒▌t,F(xiàn)ocal Loss的反向梯度傳播公式為

總結(jié):

Focal Loss主要用于解決數(shù)據(jù)不均衡問(wèn)題,可以看做是OHEM算法的延伸。作者是將Focal Loss用于one-stage detector,但實(shí)際上這種解決數(shù)據(jù)不均衡的方法對(duì)于two-stage detector來(lái)講同樣有效。

以上就是pytorch中基本的反向傳播的全部?jī)?nèi)容,希望能給大家一個(gè)參考,也希望大家多多支持W3Cschool。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教。


0 人點(diǎn)贊