0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

GRU模型实战训练 智能决策更精准

恩智浦MCU加油站 来源:恩智浦MCU加油站 2024-06-13 09:22 次阅读

上一期文章带大家认识了一个名为GRU的新朋友, GRU本身自带处理时序数据的属性,特别擅长对于时间序列的识别和检测(例如音频传感器信号等)。GRU其实是RNN模型的一个衍生形式,巧妙地设计了两个门控单元:reset门和更新门。reset门负责针对历史遗留的状态进行重置,丢弃掉无用信息;更新门负责对历史状态进行更新,将新的输入与历史数据集进行整合。通过模型训练,让模型能够自动调整这两个门控单元的状态,以期达到历史数据与最新数据和谐共存的目的。

理论知识掌握了,下面就来看看如何训练一个GRU模型吧。

训练平台选用Keras,请提前自行安装Keras开发工具。直接上代码,首先是数据导入部分,我们直接使用mnist手写字体数据集:

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GRU, Dense
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model


# 准备数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

模型构建与训练:

# 构建GRU模型
model = Sequential()
model.add(GRU(128, input_shape=(28, 28), stateful=False, unroll=False))
model.add(Dense(10, activation='softmax'))


# 编译模型
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])


# 模型训练
model.fit(x_train, y_train, batch_size=128, epochs=10, validation_data=(x_test, y_test))

这里,眼尖的伙伴应该是注意到了,GRU模型构建的时候,有两个参数,分别是stateful以及unroll,这两个参数是什么意思呢?

GRU层的stateful和unroll是两个重要的参数,它们对GRU模型的行为和性能有着重要影响:

stateful参数:默认情况下,stateful参数为False。当stateful设置为True时,表示在处理连续的数据时,GRU层的状态会被保留并传递到下一个时间步,而不是每个batch都重置状态。这对于处理时间序列数据时非常有用,例如在处理长序列时,可以保持模型的状态信息,而不是在每个batch之间重置。需要注意的是,在使用stateful时,您需要手动管理状态的重置。

unroll参数:默认情况下,unroll参数为False。当unroll设置为True时,表示在计算时会展开RNN的循环,这样可以提高计算性能,但会增加内存消耗。通常情况下,对于较短的序列,unroll设置为True可以提高计算速度,但对于较长的序列,可能会导致内存消耗过大。

通过合理设置stateful和unroll参数,可以根据具体的数据和模型需求来平衡模型的状态管理和计算性能。而我们这里用到的mnist数据集实际上并不是时间序列数据,而只是将其当作一个时序数据集来用。因此,每个batch之间实际上是没有显示的前后关系的,不建议使用stateful。而是每一个batch之后都要将其状态清零。即stateful=False。而unroll参数,大家就可以自行测试了。

模型评估与转换:

# 模型评估
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])


# 保存模型
model.save("mnist_gru_model.h5")


# 加载模型并转换
converter = tf.lite.TFLiteConverter.from_keras_model(load_model("mnist_gru_model.h5"))
tflite_model = converter.convert()


# 保存tflite格式模型
with open('mnist_gru_model.tflite', 'wb') as f:
    f.write(tflite_model)



便写好程序后,运行等待训练完毕,可以看到经过10个epoch之后,模型即达到了98.57%的测试精度:

44c1e04e-291f-11ef-91d2-92fbcf53809c.png

来看看最终的模型样子,参数stateful=False,unroll=True:

44e91506-291f-11ef-91d2-92fbcf53809c.png

这里,我们就会发现,模型的输入好像被拆分成了很多份,这是因为我们指定了输入是28*28。第一个28表示有28个时间步,后面的28则表示每一个时间步的维度。这里的时间步,指代的就是历史的数据。

现在,GRU模型训练就全部介绍完毕了,对于机器学习深度学习感兴趣的伙伴们,不妨亲自动手尝试一下,搭建并训练一个属于自己的GRU模型吧!

希望每一位探索者都能在机器学习的道路上不断前行,收获满满的知识和成果!

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • Gru
    Gru
    +关注

    关注

    0

    文章

    12

    浏览量

    7449
  • 机器学习
    +关注

    关注

    66

    文章

    8189

    浏览量

    131228
  • rnn
    rnn
    +关注

    关注

    0

    文章

    68

    浏览量

    6833

原文标题:GRU模型实战训练,智能决策更精准!

文章出处:【微信号:NXP_SMART_HARDWARE,微信公众号:恩智浦MCU加油站】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    深入GRU:解锁模型测试新维度

    之前带大家一起使用Keras训练了一个GRU模型,并使用mnist的手写字体数据集进行了验证。本期小编将继续带来一篇扩展,即GRU模型的测试
    的头像 发表于 06-27 09:36 550次阅读
    深入<b class='flag-5'>GRU</b>:解锁<b class='flag-5'>模型</b>测试新维度

    请问电脑端Pytorch训练模型如何转化为能在ESP32S3平台运行的模型

    由题目, 电脑端Pytorch训练模型如何转化为能在ESP32S3平台运行的模型? 如何把这个Pytorch模型烧录到ESP32S3上去?
    发表于 06-27 06:06

    GRU是什么?GRU模型如何让你的神经网络更聪明 掌握时间 掌握未来

    大家平时经常听到的GRU是什么呢? 首先来认识下CNN,CNN指代卷积神经网络(Convolutional Neural Network),这是一种在人工智能和机器学习领域中常用的神经网络架构,特别
    发表于 06-13 11:42 281次阅读
    <b class='flag-5'>GRU</b>是什么?<b class='flag-5'>GRU</b><b class='flag-5'>模型</b>如何让你的神经网络更聪明 掌握时间 掌握未来

    谈谈 十折交叉验证训练模型

    谈谈 十折交叉验证训练模型
    的头像 发表于 05-15 09:30 318次阅读

    微软在天气预报领域突破,新AI模型精准预测未来30天

    微软在天气预报领域取得显著成果。其Start团队成功研发了一种全新AI模型,能够更精准地预测未来30天的天气状况。
    的头像 发表于 05-10 11:23 494次阅读

    【大语言模型:原理与工程实践】大语言模型的应用

    ,它通过抽象思考和逻辑推理,协助我们应对复杂的决策。 相应地,我们设计了两类任务来检验大语言模型的能力。一类是感性的、无需理性能力的任务,类似于人类的系统1,如情感分析和抽取式问答等。大语言模型在这
    发表于 05-07 17:21

    【大语言模型:原理与工程实践】大语言模型的预训练

    大语言模型的核心特点在于其庞大的参数量,这赋予了模型强大的学习容量,使其无需依赖微调即可适应各种下游任务,而更倾向于培养通用的处理能力。然而,随着学习容量的增加,对预训练数据的需求也相应
    发表于 05-07 17:10

    【大语言模型:原理与工程实践】核心技术综述

    其预训练和微调,直到模型的部署和性能评估。以下是对这些技术的综述: 模型架构: LLMs通常采用深层的神经网络架构,最常见的是Transformer网络,它包含多个自注意力层,能够捕捉输入数据中
    发表于 05-05 10:56

    【大语言模型:原理与工程实践】揭开大语言模型的面纱

    大语言模型(LLM)是人工智能领域的尖端技术,凭借庞大的参数量和卓越的语言理解能力赢得了广泛关注。它基于深度学习,利用神经网络框架来理解和生成自然语言文本。这些模型通过训练海量的文本数
    发表于 05-04 23:55

    GPS信号探测器:让定位更精准,让生活更美好

    深圳特信屏蔽器|GPS信号探测器:让定位更精准,让生活更美好
    的头像 发表于 04-25 08:51 201次阅读

    谷歌模型训练软件有哪些?谷歌模型训练软件哪个好?

    谷歌在模型训练方面提供了一些强大的软件工具和平台。以下是几个常用的谷歌模型训练软件及其特点。
    的头像 发表于 03-01 16:24 454次阅读

    防范人工智能决策带来风险

    人工智能的机器学习方法可以通过分析数据中的模式和规律,自动训练和模拟模型,更加精准地预测潜在问题和风险的发生概率、复杂性与严重程度,事前预估不同政策方案实施后果,从而帮助
    的头像 发表于 09-27 16:46 441次阅读

    训练大语言模型带来的硬件挑战

    生成式AI和大语言模型(LLM)正在以难以置信的方式吸引全世界的目光,本文简要介绍了大语言模型训练这些模型带来的硬件挑战,以及GPU和网络行业如何针对
    的头像 发表于 09-01 17:14 1228次阅读
    <b class='flag-5'>训练</b>大语言<b class='flag-5'>模型</b>带来的硬件挑战

    人工智能训练师是什么

    人工智能训练师指的是具有相关专业能力的人士,在人工智能领域里,他们负责训练机器学习模型。与传统的计算机科学相比,机器学习是一个相对新的领域,
    的头像 发表于 08-13 14:17 1562次阅读

    训练好的ai模型导入cubemx不成功怎么解决?

    训练好的ai模型导入cubemx不成功咋办,试了好几个模型压缩了也不行,ram占用过大,有无解决方案?
    发表于 08-04 09:16