循环神经网络RNN原理与优化

news/2025/2/22 3:32:47

目录

前言

RNN背景

RNN原理

上半部分:RNN结构及按时间线展开图

下半部分:RNN在不同时刻的网络连接和计算过程

LSTM

RNN存在的问题

LSTM的结构与原理

数学表达层面

与RNN对比优势

应用场景拓展

从简易但严谨的代码来看RNN和LSTM

RNN

LSTM


前言

绕循环神经网络(RNN)、注意力机制(Attention)以及相关模型(如 LSTM、Transformer、BERT、GPT 等)在深度学习中的应用展开,介绍了其原理、结构、算法流程和实际应用场景。

RNN背景

RNN 产生的原因深度神经网络(DNN)在处理输入时,每个输入之间相互独立,无法处理序列信息。然而在自然语言处理(NLP)和视频处理等任务中,需要考虑输入元素之间的关联性,因此引入 RNN。以 NLP 中的词性标注任务为例,需处理单词序列才能准确标注词性,仅单独理解每个单词是不够的。

RNN 的结构与公式:RNN 在结构上引入了循环层,其隐藏层状态St不仅取决于当前输入Xt,还与上一时刻的隐藏层状态相关。具体公式为

输出

这样结构使RNN网络能够序列信息进行处理

RNN原理

上半部分:RNN结构及按时间线展开图

RNN结构:

输入层(Input Layer):标记为“x”,接收输入数据。

隐藏层(Hidden Layer):标记为“s”,是RNN的核心部分,包含循环连接。图中显示了权重矩阵“U”(连接输入层和隐藏层)和“W”(隐藏层的循环连接)(输入层的)

输出层(Output Layer):标记为“o”,通过权重矩阵“V”与隐藏层相连(输出层的),产生最终输出。

按时间线展开:

将RNN在时间维度上展开,展示了不同时刻(t-1, t, t+1)的网络状态。每个时刻都有输入x_t、隐藏层状态s_t和输出o_t。权重矩阵“U”、“W”和“V”在不同时刻保持不变,体现了RNN在时间上共享参数的特性。

下半部分:RNN在不同时刻的网络连接和计算过程

t-1时刻

展示了隐藏层状态s的向量形式,s=[s1, s2, ..., sn],其中每个元素代表隐藏层的一个神经元状态。权重矩阵“W”连接了t-1时刻的隐藏层神经元。

t时刻

输入层:输入向量X=[x1, x2, ..., xm],其中m是输入维度。

隐藏层:通过权重矩阵“U”接收输入层的信息,并通过权重矩阵“W”接收t-1时刻的隐藏层状态信息。图中显示了隐藏层的计算过程,即

其中f是激活函数。

输出层:根据隐藏层状态S_t,通过权重矩阵“V”计算输出

其中g是输出层的激活函数。

LSTM

LSTM(Long - Short - Term Memory,长短期记忆网络)是为解决传统循环神经网络(RNN)存在的问题而设计的。

RNN存在的问题

RNN有两个主要问题。一是短期记忆问题,当处理足够长的序列时,它难以将早期时间步的信息传递到后期。比如处理一段文本进行预测时,可能会遗漏开头的重要信息。二是梯度消失问题,在反向传播过程中,梯度随着时间反向传播而缩小。当梯度值变得极小,对神经网络权重更新的贡献就很小,导致早期的层停止学习,这也使得RNN在处理长序列时容易遗忘之前的信息。

LSTM的结构与原理

输入

当前时刻输出保存当前细胞状态(传递给下一个‘细胞’)

LSTM通过引入“细胞状态(cell state)”和“门(gate)”机制来解决上述问题:

细胞状态:就像一条传送带,在整个网络中运行,它可以在序列的不同时间步之间传递信息,使得LSTM能够处理长序列而不容易丢失早期信息。

门:

遗忘门(forget gate):决定从细胞状态中丢弃哪些信息。它读取当前输入和上一时刻隐藏状态,输出一个0 - 1之间的值,1表示“完全保留”,0表示“完全丢弃”。

输入门(input gate):确定要在细胞状态中存储哪些新信息。它包含一个sigmoid层来决定更新哪些值,以及一个tanh层来创建新的候选值向量,这些候选值可能会被添加到细胞状态中。

输出门(output gate):确定LSTM的输出。它首先通过sigmoid层决定细胞状态的哪些部分将被输出,然后将细胞状态通过tanh层(将值映射到 - 1到1之间),并将其与sigmoid层的输出相乘,得到最终的输出。

通过这些机制,LSTM能够更好地处理长序列数据,有选择性地记忆和遗忘信息,有效克服了RNN的短期记忆和梯度消失问题,这也是LSTM在后续的一些自然语言处理、语音识别等领域得到广泛应用的主要原因。

数学表达层面

遗忘门计算:

,其中W_f是权重矩阵,[h_{t - 1},x_t]是上一时刻隐藏状态和当前输入的拼接(‘细胞’传递),b_f是偏置项(截距),sigma是sigmoid激活函数,输出值在0 - 1之间,决定从细胞状态中遗忘的信息比例。

输入门计算:

确定更新值比例,

生成候选值向量,二者后续用于更新细胞状态。

细胞状态更新:

是逐元素相乘,即结合遗忘门输出、上一时刻细胞状态、输入门输出和候选值向量来更新细胞状态。

输出门计算:

决定输出比例,

得到最终隐藏状态输出。

与RNN对比优势

长期依赖处理:RNN受限于梯度消失难以保持长期依赖,LSTM通过门控机制控制细胞状态信息流,能有效保存和传递长距离信息,比如在处理长篇小说文本时,可记住开头人物关系等信息用于后续情节理解和生成。

学习效率:RNN因梯度问题早期层学习困难,LSTM通过门控灵活控制信息流动,更高效学习,在训练时间和收敛速度上表现更好,在语音识别任务中,可更快学习到语音序列中的特征模式。

应用场景拓展

自然语言处理:除常见的文本生成、机器翻译、情感分析,在文本摘要提取中,能抓住长文本关键信息;在命名实体识别中,准确识别不同类型实体。

时间序列预测:在金融领域,预测股票价格、汇率等波动;在能源领域,预测电力负荷、能源消耗等,利用其对时间序列中长短期信息的捕捉能力提高预测准确性。

视频处理:分析视频帧序列,用于动作识别、视频内容理解与生成,如判断视频中人物动作类别,生成符合逻辑的视频字幕等。

从简易但严谨的代码来看RNN和LSTM

通过pytorch框架定义只有一个’细胞RNNLSTM,进一步理解这两个网络架构应用

RNN

import torch
import torch.nn as nn

# 定义RNN模型
class SimpleRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(SimpleRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.rnn(x, h0)
        out = self.fc(out[:, -1, :])
        return out
        
# 示例参数
input_size = 10
hidden_size = 20
num_layers = 1
output_size = 5
batch_size = 3
seq_length = 8

# 创建输入数据
x = torch.randn(batch_size, seq_length, input_size)

# 实例化RNN模型
model = SimpleRNN(input_size, hidden_size, num_layers, output_size)

# 前向传播
output = model(x)
print(output.shape)

说明

定义了一个简单的SimpleRNN类继承自nn.Module。在构造函数中,初始化了 RNN 层和全连接层。nn.RNN指定了输入维度input_size、隐藏层维度hidden_size、层数num_layers,并设置batch_first=True表示输入数据的形状为(batch_size, seq_length, input_size)。

forward方法中,首先初始化隐藏状态h0,然后将输入数据x和初始隐藏状态传入 RNN 层,获取输出out。最后将 RNN 最后一个时间步的输出传入全连接层得到最终输出。

LSTM

import torch
import torch.nn as nn

# 定义LSTM模型
class SimpleLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(SimpleLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.lstm(x, (h0, c0))
        out = self.fc(out[:, -1, :])
        return out
        
# 示例参数
input_size = 10
hidden_size = 20
num_layers = 1
output_size = 5
batch_size = 3
seq_length = 8

# 创建输入数据
x = torch.randn(batch_size, seq_length, input_size)

# 实例化LSTM模型
model = SimpleLSTM(input_size, hidden_size, num_layers, output_size)

# 前向传播
output = model(x)
print(output.shape)

说明

1.定义了SimpleLSTM类,同样继承自nn.Module。构造函数中初始化了 LSTM 层和全连接层,nn.LSTM的参数设置与 RNN 类似。

2.forward方法里,除了初始化隐藏状态h0,还初始化了细胞状态c0,然后将输入x、h0和c0传入 LSTM 层,获取输出out,最后经全连接层得到最终结果。


http://www.niftyadmin.cn/n/5861560.html

相关文章

[Android]NestedScrollView嵌套RecyclerView视图点击事件冲突问题

解决: package com.mofsaas.www.ui.adapterimport android.annotation.SuppressLint import android.view.LayoutInflater import android.view.MotionEvent import android.view.View import android.view.ViewGroup import android.widget.Button import android…

网工项目实践2.4 北京公司安全加固、服务需求分析及方案制定

本专栏持续更新,整一个专栏为一个大型复杂网络工程项目。阅读本文章之前务必先看《本专栏必读》。 全网拓扑图展示 一.局域网规划设计 1.子公司北京总部局域网安全加固、网络服务需求 子公司北京总部在与运营商边界需要部署一台防火墙,保护内网的安全。…

RoboBERT:减少大规模数据与训练成本,端到端多模态机器人操作模型(西湖大学最新)

写在前面&出发点 具身智能融合多种模态,使智能体能够同时理解图像、语言和动作。然而,现有模型通常依赖额外数据集或大量预训练来最大化性能提升,这耗费了大量训练时间和高昂的硬件成本。为解决这一问题,我们提出RoboBERT&…

DeepSeek 助力 Vue 开发:打造丝滑的复制到剪贴板(Copy to Clipboard)

前言:哈喽,大家好,今天给大家分享一篇文章!并提供具体代码帮助大家深入理解,彻底掌握!创作不易,如果能帮助到大家或者给大家一些灵感和启发,欢迎收藏关注哦 💕 目录 Deep…

250217-数据结构

1. 定义 数据结构是数据的存储结构,即数据是按某些结构来存储的,比如线性结构,比如树状结构等。 2. 学习意义 数据结构是服务于算法的,为了实现算法的高效计算,所以将数据按特定结构存储。比如使用快速插入或删除的…

OpenSSL has been compiled without RC2 support

记性不好,以此记录日常遇到的问题 问题 原因 研究发现linux的OpenSSL版本过高,已经抛弃了RC2 方法 通过conda进行安装openssl,指定版本 conda install openssl1.1.1s 于此同时,如果你之前安装了uwsgi,会发现uwsgi会…

【linux】更换ollama的deepseek模型默认安装路径

【linux】更换ollama的deepseek模型默认安装路径 文章目录 【linux】更换ollama的deepseek模型默认安装路径Ollama 默认安装路径及模型存储路径迁移ollama模型到新的路径1.创建新的模型存储目录2.停止ollama3.迁移现有模型4.修改 Ollama 服务配置5.重启ollama6.验证迁移是否成功…

springboot404-基于Java的校园礼服租赁系统(源码+数据库+纯前后端分离+部署讲解等)

💕💕作者: 爱笑学姐 💕💕个人简介:十年Java,Python美女程序员一枚,精通计算机专业前后端各类框架。 💕💕各类成品Java毕设 。javaweb,ssm&#xf…