一文帶你入門NeRF:利用PyTorch實(shí)現(xiàn)NeRF代碼詳解(附代碼)
作者:大森林?| 來源:3DCV
在公眾號「3DCV」后臺,回復(fù)「原論文」即可獲取代碼。
添加微信:dddvisiona,備注:NeRF,拉你入群。文末附行業(yè)細(xì)分群。
神經(jīng)輻射場(NeRF)是一種利用神經(jīng)網(wǎng)絡(luò)來表示和渲染復(fù)雜的三維場景的方法。它可以從一組二維圖片中學(xué)習(xí)出一個連續(xù)的三維函數(shù),這個函數(shù)可以給出空間中任意位置和方向上的顏色和密度。通過體積渲染的技術(shù),NeRF可以從任意視角合成出逼真的圖像,包括透明和半透明物體,以及復(fù)雜的光線傳播效果。
NeRF模型相比于其他新的視圖合成和場景表示方法有以下幾個優(yōu)勢:
1)NeRF不需要離散化的三維表示,如網(wǎng)格或體素,因此可以避免模型精度和細(xì)節(jié)程度受到限制。NeRF也可以自適應(yīng)地處理不同形狀和大小的場景,而不需要人工調(diào)整參數(shù)。
2)NeRF使用位置編碼的方式將位置和角度信息映射到高頻域,使得網(wǎng)絡(luò)能夠更好地捕捉場景的細(xì)微結(jié)構(gòu)和變化。NeRF還使用視角相關(guān)的顏色預(yù)測,能夠生成不同視角下不同的光照效果。
3)NeRF使用分段隨機(jī)采樣的方式來近似體積渲染的積分,這樣可以保證采樣位置的連續(xù)性,同時避免網(wǎng)絡(luò)過擬合于離散點(diǎn)的信息。NeRF還使用多層級體素采樣的技巧,以提高渲染效率和質(zhì)量。
1)定義一個全連接的神經(jīng)網(wǎng)絡(luò),它的輸入是空間位置和視角方向,輸出是顏色和密度。
2)使用位置編碼的方式將輸入映射到高頻域,以便網(wǎng)絡(luò)能夠捕捉細(xì)微的結(jié)構(gòu)和變化。
3)使用分段隨機(jī)采樣的方式從每條光線上采樣一些點(diǎn),然后用神經(jīng)網(wǎng)絡(luò)預(yù)測這些點(diǎn)的顏色和密度。
4)使用體積渲染的公式計算每條光線上的顏色和透明度,作為最終的圖像輸出。
5)使用渲染損失函數(shù)來優(yōu)化神經(jīng)網(wǎng)絡(luò)的參數(shù),使得渲染的圖像與輸入的圖像盡可能接近。
import?torchimport?torch.nn?as?nnimport?torch.nn.functional?as?F#?定義一個全連接的神經(jīng)網(wǎng)絡(luò),它的輸入是空間位置和視角方向,輸出是顏色和密度。class?NeRF(nn.Module):????def?__init__(self,?D=8,?W=256,?input_ch=3,?input_ch_views=3,?output_ch=4,?skips=[4]):????????super().__init__()????????#?定義位置編碼后的位置信息的線性層,如果層數(shù)在skips列表中,則將原始位置信息與隱藏層拼接????????self.pts_linears?=?nn.ModuleList(????????????[nn.Linear(input_ch,?W)]?+?[nn.Linear(W,?W)?if?i?not?in?skips?else?nn.Linear(W?+?input_ch,?W)?for?i?in?range(D-1)])????????#?定義位置編碼后的視角方向信息的線性層????????self.views_linears?=?nn.ModuleList([nn.Linear(W?+?input_ch_views,?W//2)]?+?[nn.Linear(W//2,?W//2)?for?i?in?range(1)])????????#?定義特征向量的線性層????????self.feature_linear?=?nn.Linear(W//2,?W)????????#?定義透明度(alpha)值的線性層????????self.alpha_linear?=?nn.Linear(W,?1)????????#?定義RGB顏色的線性層????????self.rgb_linear?=?nn.Linear(W?+?input_ch_views,?3)????def?forward(self,?x):????????#?x:?(B,?input_ch?+?input_ch_views)????????#?提取位置和視角方向信息????????p?=?x[:,?:3]?#?(B,?3)????????d?=?x[:,?3:]?#?(B,?3)????????#?對輸入進(jìn)行位置編碼,將低頻信號映射到高頻域????????p?=?positional_encoding(p)?#?(B,?input_ch)????????d?=?positional_encoding(d)?#?(B,?input_ch_views)????????#?將位置信息輸入網(wǎng)絡(luò)????????h?=?p????????for?i,?l?in?enumerate(self.pts_linears):????????????h?=?l(h)????????????h?=?F.relu(h)????????????if?i?in?skips:????????????????h?=?torch.cat([h,?p],?-1)?#?如果層數(shù)在skips列表中,則將原始位置信息與隱藏層拼接????????#?將視角方向信息與隱藏層拼接,并輸入網(wǎng)絡(luò)????????h?=?torch.cat([h,?d],?-1)????????for?i,?l?in?enumerate(self.views_linears):????????????h?=?l(h)????????????h?=?F.relu(h)????????#?預(yù)測特征向量和透明度(alpha)值????????feature?=?self.feature_linear(h)?#?(B,?W)????????alpha?=?self.alpha_linear(feature)?#?(B,?1)????????????????#?使用特征向量和視角方向信息預(yù)測RGB顏色????????rgb?=?torch.cat([feature,?d],?-1)?????????rgb?=?self.rgb_linear(rgb)?#?(B,?3)????????return?torch.cat([rgb,?alpha],?-1)?#?(B,?4)#?定義位置編碼函數(shù)def?positional_encoding(x):????#?x:?(B,?C)????B,?C?=?x.shape????L?=?int(C?//?2)?#?計算位置編碼的長度????freqs?=?torch.logspace(0.,?L?-?1,?steps=L).to(x.device)?*?math.pi?#?計算頻率系數(shù),呈指數(shù)增長????freqs?=?freqs[None].repeat(B,?1)?#?(B,?L)????x_pos_enc_low?=?torch.sin(x[:,?:L]?*?freqs)?#?對前一半的輸入進(jìn)行正弦變換,得到低頻部分?(B,?L)????x_pos_enc_high?=?torch.cos(x[:,?:L]?*?freqs)?#?對前一半的輸入進(jìn)行余弦變換,得到高頻部分?(B,?L)????x_pos_enc?=?torch.cat([x_pos_enc_low,?x_pos_enc_high],?dim=-1)?#?將低頻和高頻部分拼接,得到位置編碼后的輸入?(B,?C)????return?x_pos_enc#?定義體積渲染函數(shù)def?volume_rendering(rays_o,?rays_d,?model):????#?rays_o:?(B,?3),?每條光線的起點(diǎn)????#?rays_d:?(B,?3),?每條光線的方向????B?=?rays_o.shape[0]????#?在每條光線上采樣一些點(diǎn)????near,?far?=?0.,?1.?#?近平面和遠(yuǎn)平面????N_samples?=?64?#?每條光線的采樣數(shù)????t_vals?=?torch.linspace(near,?far,?N_samples).to(rays_o.device)?#?(N_samples,)????t_vals?=?t_vals.expand(B,?N_samples)?#?(B,?N_samples)????z_vals?=?near?*?(1.?-?t_vals)?+?far?*?t_vals?#?計算每個采樣點(diǎn)的深度值?(B,?N_samples)????z_vals?=?z_vals.unsqueeze(-1)?#?(B,?N_samples,?1)????pts?=?rays_o.unsqueeze(1)?+?rays_d.unsqueeze(1)?*?z_vals?#?計算每個采樣點(diǎn)的空間位置?(B,?N_samples,?3)????#?將采樣點(diǎn)和視角方向輸入網(wǎng)絡(luò)????pts_flat?=?pts.reshape(-1,?3)?#?(B*N_samples,?3)????rays_d_flat?=?rays_d.unsqueeze(1).expand(-1,?N_samples,?-1).reshape(-1,?3)?#?(B*N_samples,?3)????x_flat?=?torch.cat([pts_flat,?rays_d_flat],?-1)?#?(B*N_samples,?6)????y_flat?=?model(x_flat)?#?(B*N_samples,?4)????y?=?y_flat.reshape(B,?N_samples,?4)?#?(B,?N_samples,?4)????#?提取RGB顏色和透明度(alpha)值????rgb?=?y[...,?:3]?#?(B,?N_samples,?3)????alpha?=?y[...,?3]?#?(B,?N_samples)????#?計算每個采樣點(diǎn)的權(quán)重????dists?=?torch.cat([z_vals[...,?1:]?-?z_vals[...,?:-1],?torch.tensor([1e10]).to(z_vals.device).expand(B,?1)],?-1)?#?計算相鄰采樣點(diǎn)之間的距離,最后一個距離設(shè)為很大的值?(B,?N_samples)????alpha?=?1.?-?torch.exp(-alpha?*?dists)?#?計算每個采樣點(diǎn)的不透明度,即1減去透明度的指數(shù)衰減?(B,?N_samples)????weights?=?alpha?*?torch.cumprod(torch.cat([torch.ones((B,?1)).to(alpha.device),?1.?-?alpha?+?1e-10],?-1),?-1)[:,?:-1]?#?計算每個采樣點(diǎn)的權(quán)重,即不透明度乘以之前所有采樣點(diǎn)的透明度累積積,最后一個權(quán)重設(shè)為0?(B,?N_samples)????#?計算每條光線的最終顏色和透明度????rgb_map?=?torch.sum(weights.unsqueeze(-1)?*?rgb,?-2)?#?加權(quán)平均每個采樣點(diǎn)的RGB顏色,得到每條光線的顏色?(B,?3)????depth_map?=?torch.sum(weights?*?z_vals.squeeze(-1),?-1)?#?加權(quán)平均每個采樣點(diǎn)的深度值,得到每條光線的深度?(B,)????acc_map?=?torch.sum(weights,?-1)?#?累加每個采樣點(diǎn)的權(quán)重,得到每條光線的不透明度?(B,)????????return?rgb_map,?depth_map,?acc_map#?定義渲染損失函數(shù)def?rendering_loss(rgb_map_pred,?rgb_map_gt):????return?((rgb_map_pred?-?rgb_map_gt)**2).mean()?#?計算預(yù)測的顏色與真實(shí)顏色之間的均方誤差
綜上所述,本代碼實(shí)現(xiàn)了NeRF的核心結(jié)構(gòu),具體實(shí)現(xiàn)內(nèi)容包括以下四個部分。
1)定義了NeRF網(wǎng)絡(luò)結(jié)構(gòu),包含位置編碼和多層全連接網(wǎng)絡(luò),輸入是位置和視角,輸出是顏色和密度。
2)實(shí)現(xiàn)了位置編碼函數(shù),通過正弦和余弦變換引入高頻信息。
3)實(shí)現(xiàn)了體積渲染函數(shù),在光線上采樣點(diǎn),查詢NeRF網(wǎng)絡(luò)預(yù)測顏色和密度,然后通過加權(quán)平均實(shí)現(xiàn)整體渲染。
4)定義了渲染損失函數(shù),計算預(yù)測顏色和真實(shí)顏色的均方誤差。
當(dāng)然,本方案只是實(shí)現(xiàn)NeRF的一個基礎(chǔ)方案,更多的細(xì)節(jié)還需要進(jìn)行優(yōu)化。需要完整學(xué)習(xí)代碼的同學(xué)可以通過下面兩個鏈接獲取:
原論文及代碼(NeRF: Neural Radiance Fields):https://github.com/bmild/nerf
大佬實(shí)現(xiàn)的pytorch版本(NeRF-pytorch):https://github.com/yenchenlin/nerf-pytorch
當(dāng)然,為了方便下載,我們已經(jīng)將上述兩個源代碼打包好了,請關(guān)注“3D視覺工坊公眾號”,回復(fù):原論文,獲取完整詳細(xì)代碼。