0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

那些年在pytorch上踩过的坑

jf_78858299 来源:天宏NLP 作者:tianhongzxy 2023-02-22 14:18 次阅读

今天又发现了一个pytorch的小坑,给大家分享一下。手上两份同一模型的代码,一份用tensorflow写的,另一份是我拿pytorch写的,模型架构一模一样,预处理数据的逻辑也一模一样,测试发现模型推理的速度也差不多。一份预处理代码是为pytorch模型写的,用到的库是torch,另一份是为tensorflow写的,用到的是numpy。在训练时,每个epoch耗时居然差距非常大,pytorch的代码在140w条数据上训练每轮耗时约45min,而tensorflow版的代码耗时仅约12min。

我把代码看了又看,百思不得其解,预处理的代码比较复杂,都包含两个for循环,pytorch版代码我把更多的预处理步骤放到了Dataset里,这样训练时加载每个batch后,再要处理的步骤就更少了,速度也应该更快,而tensorflow版代码的for循环里预处理的步骤明明更多,怎么会速度比我的代码还快呢?然而,经过我的测试发现,从加载每个batch的数据进来开始,经过预处理,直到输入到模型做计算前,两者的耗时差了约7~8倍。最后发现问题出在对pytorch的tensor进行了频繁的索引操作。

下面做个实验给大家直观体验一下,对tensor做索引和对array做索引的速度差距有多大,tensorarray都是大小(1000x1000)的二维数组。

Pytorch(version==1.4.1)索引1000000次耗时:3.51秒

图片

Numpy索引1000000次耗时:0.43秒

图片

我还特意对比了一下对TensorFlow的tensor做索引的耗时

TensorFlow(version==2.1.0)索引1000000次耗时:118.89秒

图片

由此可见tensorarray的索引速度至少差距在10倍,不过这也在情理之中,毕竟tensor要比array“重”得多。因此在使用pytorch和tensorflow时,频繁需要索引的操作一定要先把tensor转换为numpy.array来做!

除此之外,与其对二维数组进行索引,不如将其展平为一维数组,算上展平的时间,速度还会有不少提升。

Pytorch从3.51秒降到了1.94秒

图片

Numpy从0.43秒降到了0.29秒

图片

如果在训练和数据预处理过程中发现自己的代码跑起来速度非常慢,记得看一看有没有对tensor做太多次索引,如果有的话,要把它转为numpy.array,还有,尽量把二维、三维的索引变成一维的索引,这些都能加快你训练模型的速度。

PS:最后我的代码终于训练一轮也只需要不到12min了,后来又找了点加速的办法,把训练一轮的时间控制到了9min以内,这些就放在以后再写吧~

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 代码
    +关注

    关注

    30

    文章

    4741

    浏览量

    68324
  • tensorflow
    +关注

    关注

    13

    文章

    328

    浏览量

    60490
  • pytorch
    +关注

    关注

    2

    文章

    803

    浏览量

    13142
收藏 人收藏

    评论

    相关推荐

    使用STM32采集电池电压那些

    本文来解析一个盆友在使用STM32采集电池电压。以STM32F4 的ADC属于逐次逼近SAR 型ADC为例进行分析,参考STM32F405xxDatasheet,对于如何编写ADC程序就不做描述了。
    发表于 03-01 07:39

    开发STM32 USB HID

    记录一下 开发STM32 USB HID一、前言二、代码配置一、前言MCU: STM32F103C8T6CubeMX: STM32CubeMX 5.3.0二、代码配置引脚配置时钟树配置我
    发表于 08-24 07:15

    使用树莓派搭建stm32开发环境以及碰到的问题

    使用树莓派搭建stm32开发环境了很多,下面主要是记录一下,以及碰到的问题。##开发方式的选择1.使用Eclipse+GDB+O
    发表于 08-24 07:47

    Linux学习过程与如何解决

    Linux记录记录Linux学习过程与如何解决
    发表于 11-04 08:44

    移植debian系统

    基本的linux系统,板子的交叉编译器是arm-linux-gnueabihf-gcc,这给我带来了不少的麻烦,以至于想重新移植一下debian系统。ok,转入正题,说说这两天我吧。首先...
    发表于 12-14 08:42

    使用MDK5时出现的一些error分享

    使用MDK5时出现的一些error分享
    发表于 12-17 07:49

    关于RK1808板子调试过程记录

    关于RK1808板子调试过程记录
    发表于 02-16 06:38

    STM32G070CB cubemx串口调试哪些

    使用G070CB时写的中断程序是怎样的?STM32G070CB cubemx串口调试哪些呢?
    发表于 02-18 06:08

    专访技术创业工程师吴才泽:感恩这些年

    本期采访对象技术创业工程师吴才泽,这些年从工程师到创业那些呢?
    发表于 11-25 16:53 3361次阅读

    使用STM32采集电池电压资料下载

    电子发烧友网为你提供使用STM32采集电池电压资料下载的电子资料下载,更有其他相关的电路图、源代码、课件教程、中文资料、英文资料、参考设计、用户指南、解决方案等资料,希望可以帮助到广大的电子工程师们。
    发表于 04-05 08:49 73次下载
    使用STM32采集电池电压<b class='flag-5'>踩</b><b class='flag-5'>过</b>的<b class='flag-5'>坑</b>资料下载

    嵌入式Linux记录

    Linux记录记录Linux学习过程与如何解决
    发表于 11-01 17:21 10次下载
    嵌入式Linux<b class='flag-5'>踩</b><b class='flag-5'>坑</b>记录

    Arduino-IDE配置ESP32-CAM开发环境那些

    Arduino-IDE配置ESP32-CAM开发环境那些
    发表于 11-30 18:36 24次下载
    Arduino-IDE配置ESP32-CAM开发环境<b class='flag-5'>踩</b><b class='flag-5'>过</b>的<b class='flag-5'>那些</b><b class='flag-5'>坑</b>

    推挽电路的,你没?

    推挽电路的,你没?
    的头像 发表于 11-24 16:25 1058次阅读
    推挽电路的<b class='flag-5'>坑</b>,你<b class='flag-5'>踩</b><b class='flag-5'>过</b>没?

    关于图像传感器图像质量的四大误区!你几个

    关于图像传感器图像质量的四大误区!你几个
    的头像 发表于 11-27 16:56 419次阅读
    关于图像传感器图像质量的四大误区!你<b class='flag-5'>踩</b><b class='flag-5'>过</b>几个<b class='flag-5'>坑</b>?

    反相输入放大器的,你没有?

    反相输入放大器的,你没有?
    的头像 发表于 12-06 15:35 583次阅读
    反相输入放大器的<b class='flag-5'>坑</b>,你<b class='flag-5'>踩</b><b class='flag-5'>过</b>没有?