0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

PyTorch教程-3.2. 面向对象的设计实现

jf_pJlTbmA9 来源:PyTorch 作者:PyTorch 2023-06-05 15:38 次阅读

在我们对线性回归的介绍中,我们介绍了各种组件,包括数据、模型、损失函数和优化算法。事实上,线性回归是最简单的机器学习模型之一。然而,训练它使用许多与本书中其他模型所需的组件相同的组件。因此,在深入了解实现细节之前,有必要设计一些贯穿本书的 API。将深度学习中的组件视为对象,我们可以从为这些对象及其交互定义类开始。这种面向对象的实现设计将极大地简化演示,您甚至可能想在您的项目中使用它。

受PyTorch Lightning等开源库的启发,在高层次上我们希望拥有三个类:(i)Module包含模型、损失和优化方法;(ii)DataModule提供用于训练和验证的数据加载器;(iii) 两个类结合使用该类 Trainer,这使我们能够在各种硬件平台上训练模型。本书中的大部分代码都改编自Moduleand DataModule。Trainer只有在讨论 GPUCPU、并行训练和优化算法时,我们才会涉及该类。

import time
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l

import time
import numpy as np
from mxnet.gluon import nn
from d2l import mxnet as d2l

import time
from dataclasses import field
from typing import Any
import jax
import numpy as np
from flax import linen as nn
from flax.training import train_state
from jax import numpy as jnp
from d2l import jax as d2l

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

import time
import numpy as np
import tensorflow as tf
from d2l import torch as d2l

3.2.1. 公用事业

我们需要一些实用程序来简化 Jupyter 笔记本中的面向对象编程。挑战之一是类定义往往是相当长的代码块。笔记本电脑的可读性需要简短的代码片段,穿插着解释,这种要求与 Python 库常见的编程风格不相容。第一个实用函数允许我们在创建类后将函数注册为类中的方法。事实上,即使我们已经创建了类的实例,我们也可以这样做!它允许我们将一个类的实现拆分成多个代码块。

def add_to_class(Class): #@save
  """Register functions as methods in created class."""
  def wrapper(obj):
    setattr(Class, obj.__name__, obj)
  return wrapper

让我们快速浏览一下如何使用它。我们计划 A用一个方法来实现一个类do。我们可以先声明类并创建一个实例,而不是在同一个代码块中A同时 拥有两者的代码。doAa

class A:
  def __init__(self):
    self.b = 1

a = A()

do接下来我们像往常一样 定义方法,但不在 classA的范围内。相反,我们add_to_class用类A作为参数来装饰这个方法。这样做时,该方法能够访问 的成员变量,A正如我们所期望的那样,如果它已被定义为 的A定义的一部分。让我们看看当我们为实例调用它时会发生什么a。

@add_to_class(A)
def do(self):
  print('Class attribute "b" is', self.b)

a.do()

Class attribute "b" is 1

@add_to_class(A)
def do(self):
  print('Class attribute "b" is', self.b)

a.do()

Class attribute "b" is 1

@add_to_class(A)
def do(self):
  print('Class attribute "b" is', self.b)

a.do()

Class attribute "b" is 1

@add_to_class(A)
def do(self):
  print('Class attribute "b" is', self.b)

a.do()

Class attribute "b" is 1

第二个是实用程序类,它将类 __init__方法中的所有参数保存为类属性。这使我们无需额外代码即可隐式扩展构造函数调用签名。

class HyperParameters: #@save
  """The base class of hyperparameters."""
  def save_hyperparameters(self, ignore=[]):
    raise NotImplemented

我们将其实施推迟到第 23.7 节。HyperParameters要使用它,我们定义继承自该方法并调用 save_hyperparameters该方法的类__init__。

# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
  def __init__(self, a, b, c):
    self.save_hyperparameters(ignore=['c'])
    print('self.a =', self.a, 'self.b =', self.b)
    print('There is no self.c =', not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)

self.a = 1 self.b = 2
There is no self.c = True

# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
  def __init__(self, a, b, c):
    self.save_hyperparameters(ignore=['c'])
    print('self.a =', self.a, 'self.b =', self.b)
    print('There is no self.c =', not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)

self.a = 1 self.b = 2
There is no self.c = True

# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
  def __init__(self, a, b, c):
    self.save_hyperparameters(ignore=['c'])
    print('self.a =', self.a, 'self.b =', self.b)
    print('There is no self.c =', not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)

self.a = 1 self.b = 2
There is no self.c = True

# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
  def __init__(self, a, b, c):
    self.save_hyperparameters(ignore=['c'])
    print('self.a =', self.a, 'self.b =', self.b)
    print('There is no self.c =', not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)

self.a = 1 self.b = 2
There is no self.c = True

最后一个实用程序允许我们在实验进行时以交互方式绘制实验进度。为了尊重更强大(和复杂)的TensorBoard,我们将其命名为ProgressBoard。实现推迟到 第 23.7 节。现在,让我们简单地看看它的实际效果。

该方法在图中 draw绘制一个点,并在图例中指定。可选的仅通过显示来平滑线条(x, y)labelevery_n1/n图中的点。他们的价值是从平均n原始图中的邻居点。

class ProgressBoard(d2l.HyperParameters): #@save
  """The board that plots data points in animation."""
  def __init__(self, xlabel=None, ylabel=None, xlim=None,
         ylim=None, xscale='linear', yscale='linear',
         ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
         fig=None, axes=None, figsize=(3.5, 2.5), display=True):
    self.save_hyperparameters()

  def draw(self, x, y, label, every_n=1):
    raise NotImplemented

在下面的示例中,我们以不同的平滑度绘制sin和。cos如果你运行这个代码块,你会看到线条在动画中增长。

board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.1):
  board.draw(x, np.sin(x), 'sin', every_n=2)
  board.draw(x, np.cos(x), 'cos', every_n=10)

pYYBAGR5VF-ARqZBAAFz158O3us302.svg

board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.1):
  board.draw(x, np.sin(x), 'sin', every_n=2)
  board.draw(x, np.cos(x), 'cos', every_n=10)

pYYBAGR5VF-ARqZBAAFz158O3us302.svg

board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.1):
  board.draw(x, np.sin(x), 'sin', every_n=2)
  board.draw(x, np.cos(x), 'cos', every_n=10)

pYYBAGR5VF-ARqZBAAFz158O3us302.svg

board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.1):
  board.draw(x, np.sin(x), 'sin', every_n=2)
  board.draw(x, np.cos(x), 'cos', every_n=10)

pYYBAGR5VF-ARqZBAAFz158O3us302.svg

3.2.2. 楷模

该类Module是我们将要实现的所有模型的基类。我们至少需要定义三个方法。该__init__方法存储可学习参数,该training_step方法接受数据批次以返回损失值,该方法configure_optimizers返回优化方法或它们的列表,用于更新可学习参数。我们可以选择定义 validation_step报告评估措施。有时我们将计算输出的代码放入一个单独的forward方法中,以使其更具可重用性。

class Module(nn.Module, d2l.HyperParameters): #@save
  """The base class of models."""
  def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
    super().__init__()
    self.save_hyperparameters()
    self.board = ProgressBoard()

  def loss(self, y_hat, y):
    raise NotImplementedError

  def forward(self, X):
    assert hasattr(self, 'net'), 'Neural network is defined'
    return self.net(X)

  def plot(self, key, value, train):
    """Plot a point in animation."""
    assert hasattr(self, 'trainer'), 'Trainer is not inited'
    self.board.xlabel = 'epoch'
    if train:
      x = self.trainer.train_batch_idx / 
        self.trainer.num_train_batches
      n = self.trainer.num_train_batches / 
        self.plot_train_per_epoch
    else:
      x = self.trainer.epoch + 1
      n = self.trainer.num_val_batches / 
        self.plot_valid_per_epoch
    self.board.draw(x, value.to(d2l.cpu()).detach().numpy(),
            ('train_' if train else 'val_') + key,
            every_n=int(n))

  def training_step(self, batch):
    l = self.loss(self(*batch[:-1]), batch[-1])
    self.plot('loss', l, train=True)
    return l

  def validation_step(self, batch):
    l = self.loss(self(*batch[:-1]), batch[-1])
    self.plot('loss', l, train=False)

  def configure_optimizers(self):
    raise NotImplementedError

您可能会注意到它Module是nn.ModulePyTorch 中神经网络基类的子类。它提供了方便的功能来处理神经网络。例如,如果我们定义一个forward方法,例如,那么对于一个实例,我们可以通过 调用这个方法。这是有效的,因为它调用 内置方法中的方法。您可以在第 6.1 节中找到更多详细信息和示例。forward(self, X)aa(X)forward__call__nn.Module

class Module(nn.Block, d2l.HyperParameters): #@save
  """The base class of models."""
  def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
    super().__init__()
    self.save_hyperparameters()
    self.board = ProgressBoard()
  def loss(self, y_hat, y):
    raise NotImplementedError

  def forward(self, X):
    assert hasattr(self, 'net'), 'Neural network is defined'
    return self.net(X)

  def plot(self, key, value, train):
    """Plot a point in animation."""
    assert hasattr(self, 'trainer'), 'Trainer is not inited'
    self.board.xlabel = 'epoch'
    if train:
      x = self.trainer.train_batch_idx / 
        self.trainer.num_train_batches
      n = self.trainer.num_train_batches / 
        self.plot_train_per_epoch
    else:
      x = self.trainer.epoch + 1
      n = self.trainer.num_val_batches / 
        self.plot_valid_per_epoch
    self.board.draw(x, value.asnumpy(), (
      'train_' if train else 'val_') + key, every_n=int(n))
  def training_step(self, batch):
    l = self.loss(self(*batch[:-1]), batch[-1])
    self.plot('loss', l, train=True)
    return l

  def validation_step(self, batch):
    l = self.loss(self(*batch[:-1]), batch[-1])
    self.plot('loss', l, train=False)

  def configure_optimizers(self):
    raise NotImplementedError

You may notice that Module is a subclass of nn.Block, the base class of neural networks in Gluon. It provides convenient features to handle neural networks. For example, if we define a forward method, such as forward(self, X), then for an instance a we can invoke this method by a(X). This works since it calls the forward method in the built-in __call__ method. You can find more details and examples about nn.Block in Section 6.1.

With the introduction of dataclasses in Python 3.7, classes decorated with @dataclass automatically add magic methods such as __init__ and __repr__. The member variables are defined using type annotations. All Flax modules are Python 3.7 dataclasses.

class Module(nn.Module, d2l.HyperParameters): #@save
  """The base class of models."""
  # No need for save_hyperparam when using Python dataclass
  plot_train_per_epoch: int = field(default=2, init=False)
  plot_valid_per_epoch: int = field(default=1, init=False)
  # Use default_factory to make sure new plots are generated on each run
  board: ProgressBoard = field(default_factory=lambda: ProgressBoard(),
                 init=False)

  def loss(self, y_hat, y):
    raise NotImplementedError

  # JAX & Flax do not have a forward-method-like syntax. Flax uses setup
  # and built-in __call__ magic methods for forward pass. Adding here
  # for consistency
  def forward(self, X, *args, **kwargs):
    assert hasattr(self, 'net'), 'Neural network is defined'
    return self.net(X, *args, **kwargs)

  def __call__(self, X, *args, **kwargs):
    return self.forward(X, *args, **kwargs)

  def plot(self, key, value, train):
    """Plot a point in animation."""
    assert hasattr(self, 'trainer'), 'Trainer is not inited'
    self.board.xlabel = 'epoch'
    if train:
      x = self.trainer.train_batch_idx / 
        self.trainer.num_train_batches
      n = self.trainer.num_train_batches / 
        self.plot_train_per_epoch
    else:
      x = self.trainer.epoch + 1
      n = self.trainer.num_val_batches / 
        self.plot_valid_per_epoch
    self.board.draw(x, jax.device_put(value, d2l.cpu()),
            ('train_' if train else 'val_') + key,
            every_n=int(n))

  def training_step(self, params, batch, state):
    l, grads = jax.value_and_grad(self.loss)(params, batch[:-1],
                         batch[-1], state)
    self.plot("loss", l, train=True)
    return l, grads

  def validation_step(self, params, batch, state):
    l = self.loss(params, batch[:-1], batch[-1], state)
    self.plot('loss', l, train=False)

  def apply_init(self, dummy_input, key):
    """To be defined later in :numref:`sec_lazy_init`"""
    raise NotImplementedError

  def configure_optimizers(self):
    raise NotImplementedError

You may notice that Module is a subclass of linen.Module, the base class of neural networks in Flax. It provides convenient features to handle neural networks. For example, it handles the model parameters, provides the nn.compact decorator to simplify code, invokes the __call__ method among other things. Here we also redirect __call__ to the forward method. We do this to make our code more similar to other framework implementations.

class Module(tf.keras.Model, d2l.HyperParameters): #@save
  """The base class of models."""
  def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
    super().__init__()
    self.save_hyperparameters()
    self.board = ProgressBoard()
    self.training = None

  def loss(self, y_hat, y):
    raise NotImplementedError

  def forward(self, X):
    assert hasattr(self, 'net'), 'Neural network is defined'
    return self.net(X)

  def call(self, X, *args, **kwargs):
    if kwargs and "training" in kwargs:
      self.training = kwargs['training']
    return self.forward(X, *args)

  def plot(self, key, value, train):
    """Plot a point in animation."""
    assert hasattr(self, 'trainer'), 'Trainer is not inited'
    self.board.xlabel = 'epoch'
    if train:
      x = self.trainer.train_batch_idx / 
        self.trainer.num_train_batches
      n = self.trainer.num_train_batches / 
        self.plot_train_per_epoch
    else:
      x = self.trainer.epoch + 1
      n = self.trainer.num_val_batches / 
        self.plot_valid_per_epoch
    self.board.draw(x, value.numpy(), (
      'train_' if train else 'val_') + key, every_n=int(n))
  def training_step(self, batch):
    l = self.loss(self(*batch[:-1]), batch[-1])
    self.plot('loss', l, train=True)
    return l

  def validation_step(self, batch):
    l = self.loss(self(*batch[:-1]), batch[-1])
    self.plot('loss', l, train=False)

  def configure_optimizers(self):
    raise NotImplementedError

You may notice that Module is a subclass of tf.keras.Model, the base class of neural networks in TensorFlow. It provides convenient features to handle neural networks. For example, it invokes the call method in the built-in __call__ method. Here we redirect call to the forward method, saving its arguments as a class attribute. We do this to make our code more similar to other framework implementations.

3.2.3. 数据

该类DataModule是数据的基类。该方法经常__init__用于准备数据。如果需要,这包括下载和预处理。返回train_dataloader 训练数据集的数据加载器。数据加载器是一个 (Python) 生成器,每次使用时都会生成一个数据批次。然后将该批次输入到计算损失training_step的方法中。Module有一个val_dataloader返回验证数据集加载器的选项。它的行为方式相同,只是它为validation_step中的方法生成数据批次Module。

class DataModule(d2l.HyperParameters): #@save
  """The base class of data."""
  def __init__(self, root='../data', num_workers=4):
    self.save_hyperparameters()

  def get_dataloader(self, train):
    raise NotImplementedError

  def train_dataloader(self):
    return self.get_dataloader(train=True)

  def val_dataloader(self):
    return self.get_dataloader(train=False)

class DataModule(d2l.HyperParameters): #@save
  """The base class of data."""
  def __init__(self, root='../data', num_workers=4):
    self.save_hyperparameters()

  def get_dataloader(self, train):
    raise NotImplementedError

  def train_dataloader(self):
    return self.get_dataloader(train=True)

  def val_dataloader(self):
    return self.get_dataloader(train=False)

class DataModule(d2l.HyperParameters): #@save
  """The base class of data."""
  def __init__(self, root='../data'):
    self.save_hyperparameters()

  def get_dataloader(self, train):
    raise NotImplementedError

  def train_dataloader(self):
    return self.get_dataloader(train=True)

  def val_dataloader(self):
    return self.get_dataloader(train=False)

class DataModule(d2l.HyperParameters): #@save
  """The base class of data."""
  def __init__(self, root='../data'):
    self.save_hyperparameters()

  def get_dataloader(self, train):
    raise NotImplementedError

  def train_dataloader(self):
    return self.get_dataloader(train=True)

  def val_dataloader(self):
    return self.get_dataloader(train=False)

3.2.4. 训练

该类 使用中指定的数据Trainer训练类中的可学习参数。关键方法是,它接受两个参数:,一个实例,和 ,一个实例。然后它遍历整个数据集时间来训练模型。和以前一样,我们将把这个方法的实现推迟到后面的章节。ModuleDataModulefitmodelModuledataDataModulemax_epochs

class Trainer(d2l.HyperParameters): #@save
  """The base class for training models with data."""
  def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
    self.save_hyperparameters()
    assert num_gpus == 0, 'No GPU support yet'

  def prepare_data(self, data):
    self.train_dataloader = data.train_dataloader()
    self.val_dataloader = data.val_dataloader()
    self.num_train_batches = len(self.train_dataloader)
    self.num_val_batches = (len(self.val_dataloader)
                if self.val_dataloader is not None else 0)

  def prepare_model(self, model):
    model.trainer = self
    model.board.xlim = [0, self.max_epochs]
    self.model = model

  def fit(self, model, data):
    self.prepare_data(data)
    self.prepare_model(model)
    self.optim = model.configure_optimizers()
    self.epoch = 0
    self.train_batch_idx = 0
    self.val_batch_idx = 0
    for self.epoch in range(self.max_epochs):
      self.fit_epoch()

  def fit_epoch(self):
    raise NotImplementedError

The Trainer class trains the learnable parameters in the Module class with data specified in DataModule. The key method is fit, which accepts two arguments: model, an instance of Module, and data, an instance of DataModule. It then iterates over the entire dataset max_epochs times to train the model. As before, we will defer the implementation of this method to later chapters.

class Trainer(d2l.HyperParameters): #@save
  """The base class for training models with data."""
  def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
    self.save_hyperparameters()
    assert num_gpus == 0, 'No GPU support yet'

  def prepare_data(self, data):
    self.train_dataloader = data.train_dataloader()
    self.val_dataloader = data.val_dataloader()
    self.num_train_batches = len(self.train_dataloader)
    self.num_val_batches = (len(self.val_dataloader)
                if self.val_dataloader is not None else 0)

  def prepare_model(self, model):
    model.trainer = self
    model.board.xlim = [0, self.max_epochs]
    self.model = model

  def fit(self, model, data):
    self.prepare_data(data)
    self.prepare_model(model)
    self.optim = model.configure_optimizers()
    self.epoch = 0
    self.train_batch_idx = 0
    self.val_batch_idx = 0
    for self.epoch in range(self.max_epochs):
      self.fit_epoch()

  def fit_epoch(self):
    raise NotImplementedError

The Trainer class trains the learnable parameters params with data specified in DataModule. The key method is fit, which accepts three arguments: model, an instance of Module, data, an instance of DataModule, and key, a JAX PRNGKeyArray. We make the key argument optional here to simplify the interface, but it is recommended to always pass and initialize the model parameters with a root key in JAX and Flax. It then iterates over the entire dataset max_epochs times to train the model. As before, we will defer the implementation of this method to later chapters.

class Trainer(d2l.HyperParameters): #@save
  """The base class for training models with data."""
  def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
    self.save_hyperparameters()
    assert num_gpus == 0, 'No GPU support yet'

  def prepare_data(self, data):
    self.train_dataloader = data.train_dataloader()
    self.val_dataloader = data.val_dataloader()
    self.num_train_batches = len(self.train_dataloader)
    self.num_val_batches = (len(self.val_dataloader)
                if self.val_dataloader is not None else 0)

  def prepare_model(self, model):
    model.trainer = self
    model.board.xlim = [0, self.max_epochs]
    self.model = model

  def fit(self, model, data, key=None):
    self.prepare_data(data)
    self.prepare_model(model)
    self.optim = model.configure_optimizers()

    if key is None:
      root_key = d2l.get_key()
    else:
      root_key = key
    params_key, dropout_key = jax.random.split(root_key)
    key = {'params': params_key, 'dropout': dropout_key}

    dummy_input = next(iter(self.train_dataloader))[:-1]
    variables = model.apply_init(dummy_input, key=key)
    params = variables['params']

    if 'batch_stats' in variables.keys():
      # Here batch_stats will be used later (e.g., for batch norm)
      batch_stats = variables['batch_stats']
    else:
      batch_stats = {}

    # Flax uses optax under the hood for a single state obj TrainState.
    # More will be discussed later in the dropout and batch
    # normalization section
    class TrainState(train_state.TrainState):
      batch_stats: Any
      dropout_rng: jax.random.PRNGKeyArray

    self.state = TrainState.create(apply_fn=model.apply,
                    params=params,
                    batch_stats=batch_stats,
                    dropout_rng=dropout_key,
                    tx=model.configure_optimizers())
    self.epoch = 0
    self.train_batch_idx = 0
    self.val_batch_idx = 0
    for self.epoch in range(self.max_epochs):
      self.fit_epoch()

  def fit_epoch(self):
    raise NotImplementedError

The Trainer class trains the learnable parameters in the Module class with data specified in DataModule. The key method is fit, which accepts two arguments: model, an instance of Module, and data, an instance of DataModule. It then iterates over the entire dataset max_epochs times to train the model. As before, we will defer the implementation of this method to later chapters.

class Trainer(d2l.HyperParameters): #@save
  """The base class for training models with data."""
  def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
    self.save_hyperparameters()
    assert num_gpus == 0, 'No GPU support yet'

  def prepare_data(self, data):
    self.train_dataloader = data.train_dataloader()
    self.val_dataloader = data.val_dataloader()
    self.num_train_batches = len(self.train_dataloader)
    self.num_val_batches = (len(self.val_dataloader)
                if self.val_dataloader is not None else 0)

  def prepare_model(self, model):
    model.trainer = self
    model.board.xlim = [0, self.max_epochs]
    self.model = model

  def fit(self, model, data):
    self.prepare_data(data)
    self.prepare_model(model)
    self.optim = model.configure_optimizers()
    self.epoch = 0
    self.train_batch_idx = 0
    self.val_batch_idx = 0
    for self.epoch in range(self.max_epochs):
      self.fit_epoch()

  def fit_epoch(self):
    raise NotImplementedError

3.2.5. 概括

为了突出我们未来深度学习实现的面向对象设计,上面的类只是展示了它们的对象如何存储数据和相互交互。@add_to_class我们将在本书的其余部分继续丰富这些类的实现,例如 via 。此外,这些完全实现的类保存在d2l 库中,d2l 库是一个 轻量级工具包,可以轻松进行深度学习的结构化建模。特别是,它有助于在项目之间重用许多组件,而无需进行太多更改。例如,我们可以只替换优化器、模型、数据集等;这种程度的模块化在简洁和简单方面为整本书带来了好处(这就是我们添加它的原因),它可以为您自己的项目做同样的事情。

3.2.6. 练习

找到保存在d2l 库中的上述类的完整实现。我们强烈建议您在对深度学习建模有一定的了解后,再详细查看实现。

删除类save_hyperparameters中的语句B。你还能打印self.aandself.b吗?可选:如果您已经深入了解该类的完整实现HyperParameters,您能解释一下原因吗?

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • pytorch
    +关注

    关注

    2

    文章

    803

    浏览量

    13157
收藏 人收藏

    评论

    相关推荐

    Python的面向对象编程详解

    一般编程可分为面向过程编程,和面向对象编程。Python的面向对象编程,与Java的面向
    发表于 09-04 16:35 527次阅读
    Python的<b class='flag-5'>面向</b><b class='flag-5'>对象</b>编程详解

    利用LabVIEW工程库实现面向对象编程

    利用LabVIEW工程库实现面向对象编程利用LabVIEW工程库实现面向对象编程注意: 我写这篇
    发表于 12-06 12:41

    LabVIEW面向对象的ActorFramework(1)

    ` 本帖最后由 bollworm 于 2020-2-10 14:54 编辑 本系列文章主要阐述以下几个问题:(1)什么是面向对象编程?(2)为什么要学习面向编程?(3)LabVIEW面向
    发表于 02-10 14:09

    LabVIEW面向对象的ActorFramework(2)

    二、为什么要学习面向编程?面向对象编程,如果将上文推荐的两本书读完后,基本上也就有了答案。从从自我产品开发的经验中,理解为可以迅速解决中大型程序需求变化时,在不影响其他程序功能的情况下,能够
    发表于 02-18 09:20

    Labview面向对象的思考方式

    面向过程和面向对象编程的思维方式用把大象装进冰箱来描述1、面向过程的思维方式:第一步:打开冰箱门第二步:把大象推进去第三步:关上冰箱门2、面向
    发表于 04-16 14:02

    如何用C语言实现面向对象编程

    1 用C语言实现面向对象编程GOF的《设计模式》一书的副标题叫做“可复用面向对象软件的基础”,从标题就能看出
    发表于 07-12 07:24

    c语言实现面向对象编程 精选资料分享

    差异。在语法上,C语言支持的oop(面向对象)机制比较薄弱,但完全可以使用c语言写出面向对象的程序,只不过很多细节没有语法支持,需要编程人自己去实现
    发表于 09-02 07:46

    谈谈面向对象编程

    工业控制系统的PLC程序中也可以采用这种设计思想,虽然我们无法实现面向对象的很多优秀特点如“继承”,甚至于它根本就不具备面向对象编程语言的特
    发表于 09-08 07:47

    面向对象编程语言的特点

    工业控制系统的PLC程序中也可以采用这种设计思想,虽然我们无法实现面向对象的很多优秀特点如“继承”,甚至于它根本就不具备面向对象编程语言的特
    发表于 09-08 07:44

    面向对象编程介绍

    目录一、面向对象编程介绍1.面向过程编程2.函数式编程3.面向对象编程二.面向
    发表于 12-13 07:22

    plc面向对象编程架构与实现

    面向对象编程是计算机高级语言的一种先进的编程模式,在工业控制系统的PLC程序中也可以采用这种设计思想,虽然我们无法实现面向对象的很多优秀特点
    发表于 01-31 15:00 4241次阅读
    plc<b class='flag-5'>面向</b><b class='flag-5'>对象</b>编程架构与<b class='flag-5'>实现</b>

    利用Python和PyTorch处理面向对象的数据集

    本篇是利用 Python 和 PyTorch 处理面向对象的数据集系列博客的第 2 篇。 如需阅读第 1 篇:原始数据和数据集,请参阅此处。 我们在第 1 部分中已定义 MyDataset 类,现在
    的头像 发表于 08-25 15:30 2960次阅读

    利用 Python 和 PyTorch 处理面向对象的数据集(2)) :创建数据集对象

    本篇是利用 Python 和 PyTorch 处理面向对象的数据集系列博客的第 2 篇。我们在第 1 部分中已定义 MyDataset 类,现在,让我们来例化 MyDataset 对象
    的头像 发表于 08-02 17:35 903次阅读
    利用 Python 和 <b class='flag-5'>PyTorch</b> 处理<b class='flag-5'>面向</b><b class='flag-5'>对象</b>的数据集(2)) :创建数据集<b class='flag-5'>对象</b>

    PyTorch教程3.2面向对象的设计实现

    电子发烧友网站提供《PyTorch教程3.2面向对象的设计实现.pdf》资料免费下载
    发表于 06-05 15:48 0次下载
    <b class='flag-5'>PyTorch</b>教程<b class='flag-5'>3.2</b>之<b class='flag-5'>面向</b><b class='flag-5'>对象</b>的设计<b class='flag-5'>实现</b>

    PyTorch教程14.6之对象检测数据集

    电子发烧友网站提供《PyTorch教程14.6之对象检测数据集.pdf》资料免费下载
    发表于 06-05 11:23 0次下载
    <b class='flag-5'>PyTorch</b>教程14.6之<b class='flag-5'>对象</b>检测数据集