0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

GRU模型实战训练 智能决策更精准

恩智浦MCU加油站 来源:恩智浦MCU加油站 2024-06-13 09:22 次阅读

上一期文章带大家认识了一个名为GRU的新朋友, GRU本身自带处理时序数据的属性,特别擅长对于时间序列的识别和检测(例如音频传感器信号等)。GRU其实是RNN模型的一个衍生形式,巧妙地设计了两个门控单元:reset门和更新门。reset门负责针对历史遗留的状态进行重置,丢弃掉无用信息;更新门负责对历史状态进行更新,将新的输入与历史数据集进行整合。通过模型训练,让模型能够自动调整这两个门控单元的状态,以期达到历史数据与最新数据和谐共存的目的。

理论知识掌握了,下面就来看看如何训练一个GRU模型吧。

训练平台选用Keras,请提前自行安装Keras开发工具。直接上代码,首先是数据导入部分,我们直接使用mnist手写字体数据集:

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GRU, Dense
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model


# 准备数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

模型构建与训练:

# 构建GRU模型
model = Sequential()
model.add(GRU(128, input_shape=(28, 28), stateful=False, unroll=False))
model.add(Dense(10, activation='softmax'))


# 编译模型
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])


# 模型训练
model.fit(x_train, y_train, batch_size=128, epochs=10, validation_data=(x_test, y_test))

这里,眼尖的伙伴应该是注意到了,GRU模型构建的时候,有两个参数,分别是stateful以及unroll,这两个参数是什么意思呢?

GRU层的stateful和unroll是两个重要的参数,它们对GRU模型的行为和性能有着重要影响:

stateful参数:默认情况下,stateful参数为False。当stateful设置为True时,表示在处理连续的数据时,GRU层的状态会被保留并传递到下一个时间步,而不是每个batch都重置状态。这对于处理时间序列数据时非常有用,例如在处理长序列时,可以保持模型的状态信息,而不是在每个batch之间重置。需要注意的是,在使用stateful时,您需要手动管理状态的重置。

unroll参数:默认情况下,unroll参数为False。当unroll设置为True时,表示在计算时会展开RNN的循环,这样可以提高计算性能,但会增加内存消耗。通常情况下,对于较短的序列,unroll设置为True可以提高计算速度,但对于较长的序列,可能会导致内存消耗过大。

通过合理设置stateful和unroll参数,可以根据具体的数据和模型需求来平衡模型的状态管理和计算性能。而我们这里用到的mnist数据集实际上并不是时间序列数据,而只是将其当作一个时序数据集来用。因此,每个batch之间实际上是没有显示的前后关系的,不建议使用stateful。而是每一个batch之后都要将其状态清零。即stateful=False。而unroll参数,大家就可以自行测试了。

模型评估与转换:

# 模型评估
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])


# 保存模型
model.save("mnist_gru_model.h5")


# 加载模型并转换
converter = tf.lite.TFLiteConverter.from_keras_model(load_model("mnist_gru_model.h5"))
tflite_model = converter.convert()


# 保存tflite格式模型
with open('mnist_gru_model.tflite', 'wb') as f:
    f.write(tflite_model)



便写好程序后,运行等待训练完毕,可以看到经过10个epoch之后,模型即达到了98.57%的测试精度:

44c1e04e-291f-11ef-91d2-92fbcf53809c.png

来看看最终的模型样子,参数stateful=False,unroll=True:

44e91506-291f-11ef-91d2-92fbcf53809c.png

这里,我们就会发现,模型的输入好像被拆分成了很多份,这是因为我们指定了输入是28*28。第一个28表示有28个时间步,后面的28则表示每一个时间步的维度。这里的时间步,指代的就是历史的数据。

现在,GRU模型训练就全部介绍完毕了,对于机器学习深度学习感兴趣的伙伴们,不妨亲自动手尝试一下,搭建并训练一个属于自己的GRU模型吧!

希望每一位探索者都能在机器学习的道路上不断前行,收获满满的知识和成果!

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • Gru
    Gru
    +关注

    关注

    0

    文章

    12

    浏览量

    7472
  • 机器学习
    +关注

    关注

    66

    文章

    8373

    浏览量

    132389
  • rnn
    rnn
    +关注

    关注

    0

    文章

    88

    浏览量

    6872

原文标题:GRU模型实战训练,智能决策更精准!

文章出处:【微信号:NXP_SMART_HARDWARE,微信公众号:恩智浦MCU加油站】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    ADC128S022怎么能更精准的去采样?

    问题1 连续转换模式具体功能 问题2 怎么能更精准的去采样
    发表于 11-15 06:02

    如何训练自己的LLM模型

    训练自己的大型语言模型(LLM)是一个复杂且资源密集的过程,涉及到大量的数据、计算资源和专业知识。以下是训练LLM模型的一般步骤,以及一些关键考虑因素: 定义目标和需求 : 确定你的L
    的头像 发表于 11-08 09:30 318次阅读

    谷东科技民航维修智能决策模型荣获华为昇腾技术认证

    经过华为专业评测,谷东科技民航维修智能决策模型1.0成功与华为Atlas 800T A2训练服务器完成并通过了相互兼容性测试认证,正式荣获华为昇腾技术认证,被授予Ascend Com
    的头像 发表于 09-30 15:22 385次阅读

    大语言模型的预训练

    随着人工智能技术的飞速发展,自然语言处理(NLP)作为人工智能领域的一个重要分支,取得了显著的进步。其中,大语言模型(Large Language Model, LLM)凭借其强大的语言理解和生成
    的头像 发表于 07-11 10:11 378次阅读

    人脸识别模型训练流程

    人脸识别模型训练流程是计算机视觉领域中的一项重要技术。本文将详细介绍人脸识别模型训练流程,包括数据准备、模型选择、
    的头像 发表于 07-04 09:19 812次阅读

    人脸识别模型训练失败原因有哪些

    人脸识别模型训练失败的原因有很多,以下是一些常见的原因及其解决方案: 数据集质量问题 数据集是训练人脸识别模型的基础。如果数据集存在质量问题,将直接影响
    的头像 发表于 07-04 09:17 530次阅读

    人脸识别模型训练是什么意思

    人脸识别模型训练是指通过大量的人脸数据,使用机器学习或深度学习算法,训练出一个能够识别和分类人脸的模型。这个模型可以应用于各种场景,如安防监
    的头像 发表于 07-04 09:16 460次阅读

    训练模型的基本原理和应用

    训练模型(Pre-trained Model)是深度学习和机器学习领域中的一个重要概念,尤其是在自然语言处理(NLP)和计算机视觉(CV)等领域中得到了广泛应用。预训练模型指的是在大
    的头像 发表于 07-03 18:20 2299次阅读

    深度学习模型训练过程详解

    深度学习模型训练是一个复杂且关键的过程,它涉及大量的数据、计算资源和精心设计的算法。训练一个深度学习模型,本质上是通过优化算法调整模型参数,
    的头像 发表于 07-01 16:13 1064次阅读

    深入GRU:解锁模型测试新维度

    之前带大家一起使用Keras训练了一个GRU模型,并使用mnist的手写字体数据集进行了验证。本期小编将继续带来一篇扩展,即GRU模型的测试
    的头像 发表于 06-27 09:36 1149次阅读
    深入<b class='flag-5'>GRU</b>:解锁<b class='flag-5'>模型</b>测试新维度

    GRU是什么?GRU模型如何让你的神经网络更聪明 掌握时间 掌握未来

    大家平时经常听到的GRU是什么呢? 首先来认识下CNN,CNN指代卷积神经网络(Convolutional Neural Network),这是一种在人工智能和机器学习领域中常用的神经网络架构,特别
    发表于 06-13 11:42 1434次阅读
    <b class='flag-5'>GRU</b>是什么?<b class='flag-5'>GRU</b><b class='flag-5'>模型</b>如何让你的神经网络更聪明 掌握时间 掌握未来

    【大语言模型:原理与工程实践】大语言模型的应用

    ,它通过抽象思考和逻辑推理,协助我们应对复杂的决策。 相应地,我们设计了两类任务来检验大语言模型的能力。一类是感性的、无需理性能力的任务,类似于人类的系统1,如情感分析和抽取式问答等。大语言模型在这
    发表于 05-07 17:21

    【大语言模型:原理与工程实践】大语言模型的预训练

    大语言模型的核心特点在于其庞大的参数量,这赋予了模型强大的学习容量,使其无需依赖微调即可适应各种下游任务,而更倾向于培养通用的处理能力。然而,随着学习容量的增加,对预训练数据的需求也相应
    发表于 05-07 17:10

    GPS信号探测器:让定位更精准,让生活更美好

    深圳特信屏蔽器|GPS信号探测器:让定位更精准,让生活更美好
    的头像 发表于 04-25 08:51 557次阅读

    谷歌模型训练软件有哪些?谷歌模型训练软件哪个好?

    谷歌在模型训练方面提供了一些强大的软件工具和平台。以下是几个常用的谷歌模型训练软件及其特点。
    的头像 发表于 03-01 16:24 806次阅读