【TF/Guide筆記】 09. Distributed training
????感覺toc里安排的順序有點怪,我就跳著看了,直接來到最關(guān)鍵的分布式計算部分。
????MirroredStrategy
????適合單機(jī)多卡,原理是所有GPU上拷貝所有Variable,并且保持這些Variable始終是相同的,訓(xùn)練的時候每個卡分配一部分?jǐn)?shù)據(jù)。
? ? 代碼的嵌套太繞了,說實在的看不懂,只能推測是怎么實現(xiàn)的。所謂的保證Variable始終同步,其實是保證了每次對Variable的更新是相同的,這是文檔原話,但相同的是什么卻沒說,顯然這個值應(yīng)該是梯度,所以MirroredStrategy做的隱式假設(shè)是,你在跑的算法得是基于梯度下降的(雖然ML應(yīng)該都是)。
? ? 梯度下降里的梯度是可加的,大概只有這一步保證可以allreduce,因為它明確說了allreduce里使用的是加法。所以Strategy應(yīng)該是識別了代碼中tape.gradient這一步,對這個接口產(chǎn)出的結(jié)果進(jìn)行allreduce,mirrored就實現(xiàn)了,原理上應(yīng)該也是對的。所以這東西看似很通用的包一下就行了,但是里面的限制應(yīng)該還是蠻多的,或者說你也可以把MirroredStrategy當(dāng)mpi用,但它內(nèi)部對tape.gradient做了特化操作。
? ? 至于所謂基于NCCL的高效allreduce,我在源碼的注釋里找到了一句描述,他的實現(xiàn)方法就是把值都拷貝到一個設(shè)備上,計算完事兒后再廣播回去。對于單機(jī)多卡來說這種方法的確是很夠用了,畢竟誰會在一臺機(jī)器上插它10張GPU呢。當(dāng)然不排除我這代碼看叉了。
? ? 感覺MirroredStrategy適合計算任務(wù)很重的模型,因為他的前提是一個device必須得裝得下全部模型。
????TPUStrategy
????懶得看了,咱也不大可能買得到這玩意。
????MultiWorkerMirroredStrategy
????多機(jī)多卡,原理上與單機(jī)多卡相同,都是保證每個device上的Variable同步。多了一個選項是配置多機(jī)間怎么通信,除了NCCL之外提供了用rpc跑ring。
? ? 以前我們的allreduce實現(xiàn)的是樹狀的,這里并沒有這個選項,我估計tf還是假設(shè)了你不會搞那么多機(jī)器,也許是機(jī)器數(shù)太多帶來的overhead不如少放幾臺讓他慢慢跑。
? ? 在另一篇引用的文檔里明確說了,如果一個device掛了,整個程序就會gg,tf沒有自動的災(zāi)備,但你可以通過checkpoint來自己實現(xiàn),比如每跑一輪checkpoint一下,如果掛了就從最近的恢復(fù)。但是文檔里還說這種方法將來有其他接口實現(xiàn),因為讓用戶自己寫的話,他可能無法保證每個機(jī)器上的checkpoint是同一次dump下去的,他提供的BackupAndRestore幫用戶管理了模型版本。
????其實MirroredStrategy就是Syncronize training,以前我們最苦惱的就是sync模式的訓(xùn)練速度,換成tf的這個思路來看的話,如果把所有參數(shù)都放在本地存replica,然后數(shù)據(jù)并行的話,那么每一個batch有且僅有一次all_reduce就可以了,不再需要其他任何網(wǎng)絡(luò)開銷,應(yīng)該會比我們原來寫的快很多。
????ParameterServerStrategy
????文檔里說,默認(rèn)情況下,ps strategy里的worker之間是不同步的,所以ps模式就是asyncronize training。
????在看Mirrored時沒注意到一個問題,似乎需要用戶自己管理輸入文件,來保證每個worker拿到的數(shù)據(jù)是沒有交集的,也許dataset做了特殊操作,但strategy這里是沒管的。
????ps模式更類似于spark,需要有個chief節(jié)點負(fù)責(zé)分發(fā)任務(wù),而這個任務(wù)是以batch為單位的。首先coordinator.create_per_worker_dataset相當(dāng)于coordinator告訴了每個worker都自己先去把數(shù)據(jù)讀好,拿回來的iterator應(yīng)該是個類似Variable的東西,分發(fā)任務(wù)的時候再作為參數(shù)帶回給每個worker,所以框架針對任務(wù)的執(zhí)行是不用管數(shù)據(jù)流的,函數(shù)內(nèi)部要自己調(diào)用next()。
????根據(jù)文檔的說明,他不希望你依賴OutOfRangeError,也就是數(shù)據(jù)被讀完了,這時候step function會失敗退出,也許外面需要寫while (coordinator.schedule().is_ok())之類的才行。tf推薦的是你確定的執(zhí)行多少個step來結(jié)束訓(xùn)練,并且最好你的數(shù)據(jù)是可以無限iterate的。
????這樣的好處大概是你的訓(xùn)練時長與數(shù)據(jù)大小無關(guān),如果數(shù)據(jù)太小,也許每個epoch實際執(zhí)行了1.5次數(shù)據(jù),數(shù)據(jù)太大也不至于跑不完。但一個epoch不嚴(yán)格等于一份數(shù)據(jù)對我來說有點兒反常識了。
????雖然默認(rèn)下,ps模式會把分散在ps上的所有Variable shard合成一個tensor再給到運算,讓下游無感知的使用,但他也提供了特化的函數(shù),例如embedding_lookup,讓你可以只獲取需要的Variable shard,離散模型大概可以用這套接口實現(xiàn)。
????step_fn返回的這個loss,應(yīng)該是個類似Variable的某種handler,因為它的Variable涉及同步問題,所以不可能直接用Variable,應(yīng)該是strategy內(nèi)部給了某種專門用來reduce的接口。文檔里特意又調(diào)用了一次schedule來拿到這個handler,然后才loss.fetch(),不得不說有點奇怪。
????tf不支持batch級的callback函數(shù),這東西也確實不合理。
????文檔說tf的數(shù)據(jù)必須在create_per_worker_dataset全部讀完,不過從框架并沒有對輸入做假設(shè)來看,其實你完全可以自己啟一個數(shù)據(jù)流,實現(xiàn)一個基于rpc的iterator。
????CentralStorageStrategy
????與mirrored類似,適用于單機(jī)多卡,不過Variable統(tǒng)一存放在CPU內(nèi)存上,每次拷貝到GPU上做操作。目前處于試驗階段,我也想不出有什么適用場景。
????DefaultStrategy
????一個相當(dāng)于啥也不干的殼子,主要用于測試代碼。例如你不必一級一級往下傳strategy,可以通過tf.distributed.get_strategy獲取當(dāng)前scope的strategy,用來操作reduce之類的,但在本地測試的時候可以先不聲明真正的strategy,這時候get_strategy返回的就是Default Strategy,它是一個全局唯一的實例,不能自己聲明。
????OneDeviceStrategy
????顧名思義只用一個device,與default差不多,不過會在run之前把聲明在別處的Variable先拷貝到目標(biāo)device上去。