0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

利用PyTorch实现NeRF代码详解

3D视觉工坊 来源:3DCV 2023-10-21 09:46 次阅读

作者:大森林| 来源:3DCV

1. NeRF定义

神经辐射场(NeRF)是一种利用神经网络来表示和渲染复杂的三维场景的方法。它可以从一组二维图片中学习出一个连续的三维函数,这个函数可以给出空间中任意位置和方向上的颜色和密度。通过体积渲染的技术,NeRF可以从任意视角合成出逼真的图像,包括透明和半透明物体,以及复杂的光线传播效果。

2. NeRF优势

NeRF模型相比于其他新的视图合成和场景表示方法有以下几个优势:

1)NeRF不需要离散化的三维表示,如网格或体素,因此可以避免模型精度和细节程度受到限制。NeRF也可以自适应地处理不同形状和大小的场景,而不需要人工调整参数

2)NeRF使用位置编码的方式将位置和角度信息映射到高频域,使得网络能够更好地捕捉场景的细微结构和变化。NeRF还使用视角相关的颜色预测,能够生成不同视角下不同的光照效果。

3)NeRF使用分段随机采样的方式来近似体积渲染的积分,这样可以保证采样位置的连续性,同时避免网络过拟合于离散点的信息。NeRF还使用多层级体素采样的技巧,以提高渲染效率和质量。

3. NeRF实现步骤

1)定义一个全连接的神经网络,它的输入是空间位置和视角方向,输出是颜色和密度。

2)使用位置编码的方式将输入映射到高频域,以便网络能够捕捉细微的结构和变化。

3)使用分段随机采样的方式从每条光线上采样一些点,然后用神经网络预测这些点的颜色和密度。

4)使用体积渲染的公式计算每条光线上的颜色和透明度,作为最终的图像输出。

5)使用渲染损失函数来优化神经网络的参数,使得渲染的图像与输入的图像尽可能接近。

importtorch
importtorch.nnasnn
importtorch.nn.functionalasF

#定义一个全连接的神经网络,它的输入是空间位置和视角方向,输出是颜色和密度。
classNeRF(nn.Module):
def__init__(self,D=8,W=256,input_ch=3,input_ch_views=3,output_ch=4,skips=[4]):
super().__init__()
#定义位置编码后的位置信息的线性层,如果层数在skips列表中,则将原始位置信息与隐藏层拼接
self.pts_linears=nn.ModuleList(
[nn.Linear(input_ch,W)]+[nn.Linear(W,W)ifinotinskipselsenn.Linear(W+input_ch,W)foriinrange(D-1)])
#定义位置编码后的视角方向信息的线性层
self.views_linears=nn.ModuleList([nn.Linear(W+input_ch_views,W//2)]+[nn.Linear(W//2,W//2)foriinrange(1)])
#定义特征向量的线性层
self.feature_linear=nn.Linear(W//2,W)
#定义透明度(alpha)值的线性层
self.alpha_linear=nn.Linear(W,1)
#定义RGB颜色的线性层
self.rgb_linear=nn.Linear(W+input_ch_views,3)

defforward(self,x):
#x:(B,input_ch+input_ch_views)
#提取位置和视角方向信息
p=x[:,:3]#(B,3)
d=x[:,3:]#(B,3)

#对输入进行位置编码,将低频信号映射到高频域
p=positional_encoding(p)#(B,input_ch)
d=positional_encoding(d)#(B,input_ch_views)

#将位置信息输入网络
h=p
fori,linenumerate(self.pts_linears):
h=l(h)
h=F.relu(h)
ifiinskips:
h=torch.cat([h,p],-1)#如果层数在skips列表中,则将原始位置信息与隐藏层拼接

#将视角方向信息与隐藏层拼接,并输入网络
h=torch.cat([h,d],-1)
fori,linenumerate(self.views_linears):
h=l(h)
h=F.relu(h)

#预测特征向量和透明度(alpha)值
feature=self.feature_linear(h)#(B,W)
alpha=self.alpha_linear(feature)#(B,1)

#使用特征向量和视角方向信息预测RGB颜色
rgb=torch.cat([feature,d],-1)
rgb=self.rgb_linear(rgb)#(B,3)

returntorch.cat([rgb,alpha],-1)#(B,4)

#定义位置编码函数
defpositional_encoding(x):
#x:(B,C)
B,C=x.shape
L=int(C//2)#计算位置编码的长度
freqs=torch.logspace(0.,L-1,steps=L).to(x.device)*math.pi#计算频率系数,呈指数增长
freqs=freqs[None].repeat(B,1)#(B,L)
x_pos_enc_low=torch.sin(x[:,:L]*freqs)#对前一半的输入进行正弦变换,得到低频部分(B,L)
x_pos_enc_high=torch.cos(x[:,:L]*freqs)#对前一半的输入进行余弦变换,得到高频部分(B,L)
x_pos_enc=torch.cat([x_pos_enc_low,x_pos_enc_high],dim=-1)#将低频和高频部分拼接,得到位置编码后的输入(B,C)
returnx_pos_enc

#定义体积渲染函数
defvolume_rendering(rays_o,rays_d,model):
#rays_o:(B,3),每条光线的起点
#rays_d:(B,3),每条光线的方向
B=rays_o.shape[0]

#在每条光线上采样一些点
near,far=0.,1.#近平面和远平面
N_samples=64#每条光线的采样数
t_vals=torch.linspace(near,far,N_samples).to(rays_o.device)#(N_samples,)
t_vals=t_vals.expand(B,N_samples)#(B,N_samples)
z_vals=near*(1.-t_vals)+far*t_vals#计算每个采样点的深度值(B,N_samples)
z_vals=z_vals.unsqueeze(-1)#(B,N_samples,1)
pts=rays_o.unsqueeze(1)+rays_d.unsqueeze(1)*z_vals#计算每个采样点的空间位置(B,N_samples,3)

#将采样点和视角方向输入网络
pts_flat=pts.reshape(-1,3)#(B*N_samples,3)
rays_d_flat=rays_d.unsqueeze(1).expand(-1,N_samples,-1).reshape(-1,3)#(B*N_samples,3)
x_flat=torch.cat([pts_flat,rays_d_flat],-1)#(B*N_samples,6)
y_flat=model(x_flat)#(B*N_samples,4)
y=y_flat.reshape(B,N_samples,4)#(B,N_samples,4)

#提取RGB颜色和透明度(alpha)值
rgb=y[...,:3]#(B,N_samples,3)
alpha=y[...,3]#(B,N_samples)

#计算每个采样点的权重
dists=torch.cat([z_vals[...,1:]-z_vals[...,:-1],torch.tensor([1e10]).to(z_vals.device).expand(B,1)],-1)#计算相邻采样点之间的距离,最后一个距离设为很大的值(B,N_samples)
alpha=1.-torch.exp(-alpha*dists)#计算每个采样点的不透明度,即1减去透明度的指数衰减(B,N_samples)
weights=alpha*torch.cumprod(torch.cat([torch.ones((B,1)).to(alpha.device),1.-alpha+1e-10],-1),-1)[:,:-1]#计算每个采样点的权重,即不透明度乘以之前所有采样点的透明度累积积,最后一个权重设为0(B,N_samples)

#计算每条光线的最终颜色和透明度
rgb_map=torch.sum(weights.unsqueeze(-1)*rgb,-2)#加权平均每个采样点的RGB颜色,得到每条光线的颜色(B,3)
depth_map=torch.sum(weights*z_vals.squeeze(-1),-1)#加权平均每个采样点的深度值,得到每条光线的深度(B,)
acc_map=torch.sum(weights,-1)#累加每个采样点的权重,得到每条光线的不透明度(B,)

returnrgb_map,depth_map,acc_map

#定义渲染损失函数
defrendering_loss(rgb_map_pred,rgb_map_gt):
return((rgb_map_pred-rgb_map_gt)**2).mean()#计算预测的颜色与真实颜色之间的均方误差

综上所述,本代码实现了NeRF的核心结构,具体实现内容包括以下四个部分。

1)定义了NeRF网络结构,包含位置编码和多层全连接网络,输入是位置和视角,输出是颜色和密度。

2)实现了位置编码函数,通过正弦和余弦变换引入高频信息。

3)实现了体积渲染函数,在光线上采样点,查询NeRF网络预测颜色和密度,然后通过加权平均实现整体渲染。

4)定义了渲染损失函数,计算预测颜色和真实颜色的均方误差。

当然,本方案只是实现NeRF的一个基础方案,更多的细节还需要进行优化。

当然,为了方便下载,我们已经将上述两个源代码打包好了。

审核编辑:汤梓红

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 神经网络
    +关注

    关注

    42

    文章

    4762

    浏览量

    100520
  • 函数
    +关注

    关注

    3

    文章

    4304

    浏览量

    62412
  • 代码
    +关注

    关注

    30

    文章

    4741

    浏览量

    68326
  • pytorch
    +关注

    关注

    2

    文章

    803

    浏览量

    13142

原文标题:一文带你入门NeRF:利用PyTorch实现NeRF代码详解(附代码)

文章出处:【微信号:3D视觉工坊,微信公众号:3D视觉工坊】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    Image Style Transfer pytorch方式实现的主要思路

    深度学总结:Image Style Transfer pytorch方式实现,这个是非基于autoencoder和domain adversrial方式
    发表于 06-20 10:58

    PyTorch如何入门

    PyTorch 入门实战(一)——Tensor
    发表于 06-01 09:58

    Pytorch代码移植嵌入式开发笔记,错过绝对后悔

    @[TOC]Pytorch 代码移植嵌入式开发笔记目前在做开发完成后的AI模型移植到前端的工作。 由于硬件设施简陋,需要把代码和算法翻译成基础加乘算法并输出每个环节参数。记录几点实用技巧以及项目
    发表于 11-08 08:24

    单片机点灯的基本语法代码详解

    【单片机】点灯基本语法代码详解代码详解#include #include //功能:实现P1口左移#define uchar unsigne
    发表于 02-16 06:34

    PyTorch官网教程PyTorch深度学习:60分钟快速入门中文翻译版

    PyTorch 深度学习:60分钟快速入门”为 PyTorch 官网教程,网上已经有部分翻译作品,随着PyTorch1.0 版本的公布,这个教程有较大的代码改动,本人对教程进行重新翻
    的头像 发表于 01-13 11:53 1w次阅读

    Pytorch 1.1.0,来了!

    许多用户已经转向使用标准PyTorch运算符编写自定义实现,但是这样的代码遭受高开销:大多数PyTorch操作在GPU上启动至少一个内核,并且RNN由于其重复性质通常运行许多操作。但是
    的头像 发表于 05-05 10:02 5887次阅读
    <b class='flag-5'>Pytorch</b> 1.1.0,来了!

    详解Tutorial代码的学习过程与准备

    导读:本文主要解析Pytorch Tutorial中BiLSTM_CRF代码,几乎注释了每行代码,希望本文能够帮助大家理解这个tutorial,除此之外借助代码和图解也对理解条件随机场
    的头像 发表于 04-03 16:50 1837次阅读
    <b class='flag-5'>详解</b>Tutorial<b class='flag-5'>代码</b>的学习过程与准备

    Pytorch实现MNIST手写数字识别

    Pytorch 实现MNIST手写数字识别
    发表于 06-16 14:47 7次下载

    pytorch实现断电继续训练时需要注意的要点

    本文整理了pytorch实现断电继续训练时需要注意的要点,附有代码详解
    的头像 发表于 08-22 09:50 1385次阅读

    PyTorch教程3.2之面向对象的设计实现

    电子发烧友网站提供《PyTorch教程3.2之面向对象的设计实现.pdf》资料免费下载
    发表于 06-05 15:48 0次下载
    <b class='flag-5'>PyTorch</b>教程3.2之面向对象的设计<b class='flag-5'>实现</b>

    PyTorch教程3.5之线性回归的简洁实现

    电子发烧友网站提供《PyTorch教程3.5之线性回归的简洁实现.pdf》资料免费下载
    发表于 06-05 11:28 0次下载
    <b class='flag-5'>PyTorch</b>教程3.5之线性回归的简洁<b class='flag-5'>实现</b>

    PyTorch教程13.6之多个GPU的简洁实现

    电子发烧友网站提供《PyTorch教程13.6之多个GPU的简洁实现.pdf》资料免费下载
    发表于 06-05 14:21 0次下载
    <b class='flag-5'>PyTorch</b>教程13.6之多个GPU的简洁<b class='flag-5'>实现</b>

    [源代码]Python算法详解

    [源代码]Python算法详解[源代码]Python算法详解
    发表于 06-06 17:50 0次下载

    TorchFix:基于PyTorch代码静态分析

    TorchFix是我们最近开发的一个新工具,旨在帮助PyTorch用户维护健康的代码库并遵循PyTorch的最佳实践。首先,我想要展示一些我们努力解决的问题的示例。
    的头像 发表于 12-18 15:20 1034次阅读

    全面总结动态NeRF

    1. 摘要 神经辐射场(NeRF)是一种新颖的隐式方法,可以实现高分辨率的三维重建和表示。在首次提出NeRF的研究之后,NeRF获得了强大的发展力量,并在三维建模、表示和重建领域蓬勃发
    的头像 发表于 11-14 16:48 144次阅读
    全面总结动态<b class='flag-5'>NeRF</b>