电子发烧友App

硬声App

0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示
创作
电子发烧友网>电子资料下载>电子资料>PyTorch教程6.2之参数管理

PyTorch教程6.2之参数管理

2023-06-05 | pdf | 0.13 MB | 次下载 | 免费

资料介绍

一旦我们选择了一个架构并设置了我们的超参数,我们就进入训练循环,我们的目标是找到最小化损失函数的参数值。训练后,我们将需要这些参数来进行未来的预测。此外,我们有时会希望提取参数以在其他上下文中重用它们,将我们的模型保存到磁盘以便它可以在其他软件中执行,或者进行检查以期获得科学理解。

大多数时候,我们将能够忽略参数声明和操作的具体细节,依靠深度学习框架来完成繁重的工作。然而,当我们远离具有标准层的堆叠架构时,我们有时需要陷入声明和操作参数的困境。在本节中,我们将介绍以下内容:

  • 访问用于调试、诊断和可视化的参数。

  • 跨不同模型组件共享参数。

import torch
from torch import nn
from mxnet import init, np, npx
from mxnet.gluon import nn

npx.set_np()
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf

我们首先关注具有一个隐藏层的 MLP。

net = nn.Sequential(nn.LazyLinear(8),
          nn.ReLU(),
          nn.LazyLinear(1))

X = torch.rand(size=(2, 4))
net(X).shape
torch.Size([2, 1])
net = nn.Sequential()
net.add(nn.Dense(8, activation='relu'))
net.add(nn.Dense(1))
net.initialize() # Use the default initialization method

X = np.random.uniform(size=(2, 4))
net(X).shape
(2, 1)
net = nn.Sequential([nn.Dense(8), nn.relu, nn.Dense(1)])

X = jax.random.uniform(d2l.get_key(), (2, 4))
params = net.init(d2l.get_key(), X)
net.apply(params, X).shape
(2, 1)
net = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(4, activation=tf.nn.relu),
  tf.keras.layers.Dense(1),
])

X = tf.random.uniform((2, 4))
net(X).shape
TensorShape([2, 1])

6.2.1. 参数访问

让我们从如何从您已知的模型中访问参数开始。

当通过类定义模型时Sequential,我们可以首先通过索引模型来访问任何层,就好像它是一个列表一样。每个层的参数都方便地位于其属性中。

When a model is defined via the Sequential class, we can first access any layer by indexing into the model as though it were a list. Each layer’s parameters are conveniently located in its attribute.

Flax and JAX decouple the model and the parameters as you might have observed in the models defined previously. When a model is defined via the Sequential class, we first need to initialize the network to generate the parameters dictionary. We can access any layer’s parameters through the keys of this dictionary.

When a model is defined via the Sequential class, we can first access any layer by indexing into the model as though it were a list. Each layer’s parameters are conveniently located in its attribute.

我们可以如下检查第二个全连接层的参数。

net[2].state_dict()
OrderedDict([('weight',
       tensor([[-0.2523, 0.2104, 0.2189, -0.0395, -0.0590, 0.3360, -0.0205, -0.1507]])),
       ('bias', tensor([0.0694]))])
net[1].params
dense1_ (
 Parameter dense1_weight (shape=(1, 8), dtype=float32)
 Parameter dense1_bias (shape=(1,), dtype=float32)
)
params['params']['layers_2']
FrozenDict({
  kernel: Array([[-0.20739523],
      [ 0.16546965],
      [-0.03713543],
      [-0.04860032],
      [-0.2102929 ],
      [ 0.163712 ],
      [ 0.27240783],
      [-0.4046879 ]], dtype=float32),
  bias: Array([0.], dtype=float32),
})
net.layers[2].weights
[<tf.Variable 'dense_1/kernel:0' shape=(4, 1) dtype=float32, numpy=
 array([[-0.52124995],
    [-0.22314149],
    [ 0.20780373],
    [ 0.6839919 ]], dtype=float32)>,
 <tf.Variable 'dense_1/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>]

我们可以看到这个全连接层包含两个参数,分别对应于该层的权重和偏差。

6.2.1.1. 目标参数

请注意,每个参数都表示为参数类的一个实例。要对参数做任何有用的事情,我们首先需要访问基础数值。做这件事有很多种方法。有些更简单,有些则更通用。以下代码从返回参数类实例的第二个神经网络层中提取偏差,并进一步访问该参数的值。

type(net[2].bias), net[2].bias.data
(torch.nn.parameter.Parameter, tensor([0.0694]))

参数是复杂的对象,包含值、梯度和附加信息这就是为什么我们需要显式请求该值。

除了值之外,每个参数还允许我们访问梯度。因为我们还没有为这个网络调用反向传播,所以它处于初始状态。

net[2].weight.grad == None
True
type(net[1].bias), net[1].bias.data()
(mxnet.gluon.parameter.Parameter, array([0.]))

Parameters are complex objects, containing values, gradients, and additional information. That is why we need to request the value explicitly.

In addition to the value, each parameter also allows us to access the gradient. Because we have not invoked backpropagation for this network yet, it is in its initial state.

net[1].weight.grad()
array([[0., 0., 0., 0., 0., 0., 0., 0.]])
bias = params['params']['layers_2']['bias']
type(bias), bias
(jaxlib.xla_extension.Array, Array([0.], dtype=float32))

Unlike the other frameworks, JAX does not keep a track of the gradients over the neural network parameters, instead the parameters and the network are decoupled. It allows the user to express their computation as a Python function, and use the grad transformation for the same purpose.

type(net.layers[2].weights[1]), tf.convert_to_tensor(net.layers[2].weights[1])
(tensorflow.python.ops.resource_variable_ops.ResourceVariable,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>)

6.2.1.2. 一次所有参数

当我们需要对所有参数执行操作时,一个一个地访问它们会变得乏味。当我们使用更复杂的模块(例如,嵌套模块)时,情况会变得特别笨拙,因为我们需要递归遍历整个树以提取每个子模块的参数。下面我们演示访问所有层的参数。

[(name, param.shape) for name, param in net.named_parameters()]
[('0.weight', torch

评论

查看更多

下载排行

本周

  1. 1Keysight B1500A 半导体器件分析仪用户手册、说明书 (中文)
  2. 19.00 MB  |  4次下载  |  免费
  3. 2使用TL431设计电源
  4. 0.67 MB   |  2次下载  |  免费
  5. 3BT134双向可控硅手册
  6. 1.74 MB   |  2次下载  |  1 积分
  7. 4一种新型高效率的服务器电源系统
  8. 0.85 MB   |  1次下载  |  1 积分
  9. 5LabVIEW环形控件
  10. 0.01 MB   |  1次下载  |  1 积分
  11. 6PR735,使用UCC28060的600W交错式PFC转换器
  12. 540.03KB   |  1次下载  |  免费
  13. 751单片机核心板原理图
  14. 0.12 MB   |  1次下载  |  5 积分
  15. 8BP2879DB支持调光调灭的非隔离低 PF LED 驱动器
  16. 1.44 MB  |  1次下载  |  免费

本月

  1. 1开关电源设计原理手册
  2. 1.83 MB   |  54次下载  |  免费
  3. 2FS5080E 5V升压充电两串锂电池充电管理IC中文手册
  4. 8.45 MB   |  23次下载  |  免费
  5. 3DMT0660数字万用表产品说明书
  6. 0.70 MB   |  13次下载  |  免费
  7. 4UC3842/3/4/5电源管理芯片中文手册
  8. 1.75 MB   |  12次下载  |  免费
  9. 5ST7789V2单芯片控制器/驱动器英文手册
  10. 3.07 MB   |  11次下载  |  1 积分
  11. 6TPS54202H降压转换器评估模块用户指南
  12. 1.02MB   |  8次下载  |  免费
  13. 7STM32F101x8/STM32F101xB手册
  14. 1.69 MB   |  8次下载  |  1 积分
  15. 8基于MSP430FR6043的超声波气体流量计快速入门指南
  16. 2.26MB   |  7次下载  |  免费

总榜

  1. 1matlab软件下载入口
  2. 未知  |  935119次下载  |  10 积分
  3. 2开源硬件-PMP21529.1-4 开关降压/升压双向直流/直流转换器 PCB layout 设计
  4. 1.48MB  |  420061次下载  |  10 积分
  5. 3Altium DXP2002下载入口
  6. 未知  |  233084次下载  |  10 积分
  7. 4电路仿真软件multisim 10.0免费下载
  8. 340992  |  191367次下载  |  10 积分
  9. 5十天学会AVR单片机与C语言视频教程 下载
  10. 158M  |  183335次下载  |  10 积分
  11. 6labview8.5下载
  12. 未知  |  81581次下载  |  10 积分
  13. 7Keil工具MDK-Arm免费下载
  14. 0.02 MB  |  73807次下载  |  10 积分
  15. 8LabVIEW 8.6下载
  16. 未知  |  65987次下载  |  10 积分