博客
关于我
Pytorch_第三篇_Pytorch Autograd (自动求导机制)
阅读量:705 次
发布时间:2019-03-17

本文共 1949 字,大约阅读时间需要 6 分钟。

PyTorch Autograd (自动求导机制)

PyTorch 的 Autograd 库是训练神经网络时反向误差传播 (Backpropagation, BP) 算法的核心。在本文中,我们将通过 logistic 回归模型来理解 PyTorch 的自动求导机制。首先,我们将介绍与求导相关的 tensor 属性;其次,通过 logistic 回归模型来阐述前向传播和反向传播的过程。

Tensor Attributes Related to Derivation

在 PyTorch 中, tensor 对求导支持多种属性和操作。以下是一些关键属性和操作:

  • x.requires_grad:标记 tensor 是否需要在反向传播过程中求导。
  • with torch.no_grad(): 在模型评估时用于禁用求导,减少计算开销。
  • x.grad:存储损失函数对该 tensor 的偏导值,需在调用 backward() 后访问。
  • x.grad_fn:存储计算图中某些中间操作的函数,用于指导反向传播。
  • x.is_leaf:标记 tensor 是否为叶子张量。叶子张量通常是手动创建的参数,如神经网络中的权值矩阵。
  • x.detach():返回 tensor 的数据及 requires_grad 属性,返回的 tensor 与原 tensor 共享存储空间。建议使用此方法避免求导错误。
  • x.item():将 0维 张量(标量)转换为 Python 标量。
  • x.tolist():将张量转换为 Python 列表。

Build Logistic Regression Model

假设有以下损失函数:

[ z = w_1x_1 + w_2x_2 + b ]

[ y_p = \text{sigmoid}(z) ]

[ \text{Loss}(y_p, y_t) = -\frac{1}{n}\sum_{i=1}^n (y_t \log(y_p) + (1-y_t)\log(1-y_p)) ]

通过上述模型,我们可以构建一个简单的计算图:

  • 输入层通过权值参数 w1 和 w2 进行线性变换,输出层通过阈值参数 b 加上偏置。
  • 通过 sigmoid 函数将线性变换结果转换为概率输出。
  • 计算损失函数。
  • 在反向传播过程中,我们需要计算损失函数对权值参数 w1、w2 和阈值参数 b 的梯度,以便使用 SGD 等优化算法进行参数更新。

    PyTorch 实现

    以下是使用 PyTorch 实现 logistic 回归模型的代码示例:

    import torchimport numpy as npx_t, y_t = torch.tensor([[1, 1], [1, 0], [0, 1], [0, 0]], requires_grad=False, dtype=torch.float)y_t = torch.tensor([[0], [1], [0], [1]], requires_grad=False, dtype=torch.float)w = torch.randn([2, 1], requires_grad=True, dtype=torch.float)b = torch.zeros(1, requires_grad=True, dtype=torch.float)def logistic_model(x_t):    a = torch.matmul(x_t, w) + b    return torch.sigmoid(a)y_p = logistic_model(x_t)def get_loss(y_p, y_t):    return -torch.mean(y_t * torch.log(y_p) + (1 - y_t) * torch.log(1 - y_p))loss = get_loss(y_p, y_t)loss.backward()w.grad.zero_()b.grad.zero_()w.data -= 1e-2 * w.grad.datab.data -= 1e-2 * b.grad.dataprint(f"epoch: {e}, loss: {loss.item()}")print(w)print(b)

    代码解释:

  • 定义输入数据 tensor x_t 和目标输出 tensor y_t。
  • 初始化权值参数 w 和阈值参数 b。
  • 定义 logistic 回归模型函数。
  • 前向传播计算预测值 y_p。
  • 定义损失函数。
  • 反向传播计算梯度。
  • 更新权值和阈值参数。
  • 打印损失值和参数状态。
  • 运行代码后,我们可以观察到损失值随着迭代逐步下降,模型性能逐步提升。

    转载地址:http://yqvez.baihongyu.com/

    你可能感兴趣的文章
    param[:]=param-lr*param.grad/batch_size的理解
    查看>>
    spring mvc excludePathPatterns失效 如何解决spring拦截器失效 excludePathPatterns忽略失效 拦截器失效 spring免验证拦截器不起作用
    查看>>
    Spring Cloud 之注册中心 EurekaServerAutoConfiguration源码分析
    查看>>
    Parrot OS 6.2 重磅发布!推出全新 Docker 容器启动器
    查看>>
    Parrot OS 6.3 发布!全面提升安全性,新增先进工具,带来更高性能
    查看>>
    ParseChat应用源码ios版
    查看>>
    Part 2异常和错误
    查看>>
    Pascal Script
    查看>>
    Spring Boot集成Redis实现keyspace监听 | Spring Cloud 34
    查看>>
    Spring Boot中的自定义事件详解与实战
    查看>>
    Passport 密码模式
    查看>>
    Spring Boot(七十六):集成Redisson实现布隆过滤器(Bloom Filter)
    查看>>
    passwd命令限制用户密码到期时间
    查看>>
    Spring Boot 动态加载jar包,动态配置太强了!
    查看>>
    Spring @Async执行异步方法的简单使用
    查看>>
    PAT (Basic Level) Practice 乙级1021-1030
    查看>>
    PAT (Basic Level) Practice 乙级1031-1040
    查看>>
    PAT (Basic Level) Practice 乙级1041-1045
    查看>>
    SparkSql的元数据
    查看>>
    PAT (Basic Level) Practice 乙级1051-1055
    查看>>