0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

利用TensorFlow.js,D3.js 和 Web 的力量使训练模型的过程可视化

Tensorflowers 来源:未知 作者:李倩 2018-08-08 14:24 次阅读

在这篇文章中,我们将利用 TensorFlow.js,D3.js 和 Web 的力量使训练模型的过程可视化,以预测棒球数据中的坏球(蓝色区域)和好球(橙色区域)。 随着我们的进展,我们将模型在整个训练过程中理解的打击区域可视化。您可以通过访问此 Observable 笔记本在浏览器中运行此模型。

注:Observable链接

https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d

如果你不熟悉棒球的击球区,这里有一篇详细的文章。

上面的 GIF 可视化神经网络学习调用坏球(蓝色区域)和好球(橙色区域)在每个训练步骤之后,热图会根据模型的预测进行更新

使用 Observable 直接在浏览器中运行此模型。

注:文章链接

https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d

体育运动中的高级指标

当今的职业体育环境中充斥着大量的数据。这些数据被团队,业余爱好者和粉丝应用于各种用例中。感谢像 TensorFlow 这样的框架 - 这些数据集已准备好应用于机器学习

美国职业棒球大联盟先进媒体(MLBAM)的 PITCHf/x

美国职业棒球大联盟先进媒体(MLBAM)发布了一个可供公众研究的大型数据集。该数据集包含有关过去几年在美国职业棒球大联盟比赛中投掷的投球的传感器信息。 利用这个数据集,我们已编写了一个包含 5,000 个样本的训练集(2,500 个坏球和 2,500 个好球)。

以下是训练数据中前几个字段的示例:

注:示例链接

https://gist.github.com/nkreeger/01b5386b522b0cd1f22bc864320f3084#file-baseball-training-data-sample-csv

以下是针对打击区域绘制的训练数据的样子。蓝点标记为坏球,橙点标记为好球(此为大联盟裁判员称谓):

利用 TensorFlow.js 构建模型

TensorFlow.js 将机器学习引入 JavaScript 和 Web。 我们将利用这个很棒的框架来构建一个深度神经网络模型。这个模型将能够按大联盟裁判的精准度来称呼好球和坏球。

输入 Input

该模型在 PITCHf / x 的以下字段中进行了训练:

协调球越过本垒的位置('px'和'pz')。

击球手站在垒的哪一侧。

击球区(击球手的躯干)的高度,以英尺为单位。

击球区底部的高度(击球手的膝盖)以英尺为单位。

裁判所称的投球(好球或坏球)的实际标签

结构 Architecture

该模型将通过使用 TensorFlow.js 图层 API 定义。Layers API 基于 Keras,对以前使用过该框架的人来说应该很熟悉:

1const model = tf.sequential();

2

3// Two fully connected layers with dropout between each:

4model.add(tf.layers.dense({units: 24, activation: 'relu', inputShape: [5]}));

5model.add(tf.layers.dropout({rate: 0.01}));

6model.add(tf.layers.dense({units: 16, activation: 'relu'}));

7model.add(tf.layers.dropout({rate: 0.01}));

8

9// Only two classes: "strike" and "ball":

10model.add(tf.layers.dense({units: 2, activation: 'softmax'}));

11

12model.compile({

13optimizer: tf.train.adam(0.01),

14loss: 'categoricalCrossentropy',

15metrics: ['accuracy']

16});

加载和准备数据

精选的训练集可通过GitHub gist 获得。需要下载此数据集才能开始将 CSV 数据转换为 TensorFlow.js 用于训练的格式。

注:GitHub gist 链接

https://gist.github.com/nkreeger/43edc6e6daecc2cb02a2dd3293a08f29

1const data = [];

2csvData.forEach((values) => {

3// 'logit' data uses the 5 fields:

4const x = [];

5x.push(parseFloat(values.px));

6x.push(parseFloat(values.pz));

7x.push(parseFloat(values.sz_top));

8x.push(parseFloat(values.sz_bot));

9x.push(parseFloat(values.left_handed_batter));

10// The label is simply 'is strike' or 'is ball':

11const y = parseInt(values.is_strike, 10);

12data.push({x: x, y: y});

13});

14// Shuffle the contents to ensure the model does not always train on the same

15// sequence of pitch data:

16tf.util.shuffle(data);

解析 CSV 数据后,需要将 JS 类型转换为 Tensor 批次进行培训和评估。有关此过程的详细信息,请参阅代码实验室。TensorFlow.js 团队正在开发一种新的 Data API,以便将来更容易获取。

注:代码实验室

https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d#batches

训练模型

让我们把这一切都整合在一起吧。定义了模型,准备好了训练数据,现在我们已经准备好开始训练了。以下异步方法训练一批训练样本并更新热图:

1// Trains and reports loss+accuracy for one batch of training data:

2async function trainBatch(index) {

3const history = await model.fit(batches[index].x, batches[index].y, {

4epochs: 1,

5shuffle: false,

6validationData: [batches[index].x, batches[index].y],

7batchSize: CONSTANTS.BATCH_SIZE

8});

9

10// Don't block the UI frame by using tf.nextFrame()

11await tf.nextFrame();

12updateHeatmap();

13await tf.nextFrame();

14}

可视化模型的准确性

使用来自均匀放置在本垒板上方的 4 x 4 英尺栅格的预测矩阵来构建热图。在每个训练步骤之后将该矩阵传递到模型中以检查模型的准确度。使用 D3 库将该预测的结果呈现为热图。

构建预测矩阵

热图中使用的预测矩阵从本垒板的中间开始,向左和向右各延伸 2 英尺。它的范围也从本垒板的底部到 4 英尺高。击打区样本位于本垒板上方 1.5 至 3.5 英尺之间。下图有助于让这些 2d 窗格可视化:

该视觉显示了打击区域和预测矩阵与本垒板和游戏区域相关的位置

将预测矩阵与模型一起使用

每个批次在模型中训练之后,预测矩阵被传递到模型中用以请求矩阵中的好球或坏球预测:

1function predictZone() {

2const predictions = model.predictOnBatch(predictionMatrix.data);

3const values = predictions.dataSync();

4

5// Sort each value so the higher prediction is the first element in the array:

6const results = [];

7let index = 0;

8for (let i = 0; i < values.length; i++) {    

9let list = [];

10list.push({value: values[index++], strike: 0});

11list.push({value: values[index++], strike: 1});

12list = list.sort((a, b) => b.value - a.value);

13results.push(list);

14}

15return results;

16}

热图与 D3

现在可以使用 D3 显示预测结果。 来自 50x50 网格中的每一个元素将在 SVG 中呈现为 10px x 10px 的矩形。每个矩形的颜色取决于预测结果(好球或者坏球)以及模型对该结果的确定程度(范围从 50%-100%)。 以下代码段显示了如何从 D3 svg 矩形分组更新数据:

1function updateHeatmap() {

2rects.data(generateHeatmapData());

3rects

4.attr('x', (coord) => { return scaleX(coord.x) * CONSTANTS.HEATMAP_SIZE; })

5.attr('y', (coord) => { return scaleY(coord.y) * CONSTANTS.HEATMAP_SIZE; })

6.attr('width', CONSTANTS.HEATMAP_SIZE)

7.attr('height', CONSTANTS.HEATMAP_SIZE)

8.style('fill', (coord) => {

9if (coord.strike) {

10return strikeColorScale(coord.value);

11} else {

12return ballColorScale(coord.value);

13}

14});

15}

有关使用 D3 绘制热图的完整详细信息,请参阅此部分。

注:此部分链接

https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d#colorDomain

总结

网络上有许多令人惊叹的第三方库和工具,可用于创建视觉效果。将这些与机器学习的强大功能与 TensorFlow.js 相结合,开发人员能够创建一些非常新奇有趣的演示。

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 神经网络
    +关注

    关注

    42

    文章

    4732

    浏览量

    100373
  • 机器学习
    +关注

    关注

    66

    文章

    8337

    浏览量

    132255
  • tensorflow
    +关注

    关注

    13

    文章

    328

    浏览量

    60458

原文标题:棒球比赛中是好球还是坏球?TensorFlow.js 已经知道

文章出处:【微信号:tensorflowers,微信公众号:Tensorflowers】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    Tensorflow之Tensorboard的可视化使用

    TF之Tensorboard:Tensorflow之Tensorboard可视化使用之详细攻略
    发表于 12-27 10:05

    Keras可视化神经网络架构的4种方法

    我们在使用卷积神经网络或递归神经网络或其他变体时,通常都希望对模型的架构可以进行可视化的查看,因为这样我们可以 在定义和训练多个模型时,比较不同的层以及它们放置的顺序对结果的影响。还有
    发表于 11-02 14:55

    keras可视化介绍

    keras可视化可以帮助我们直观的查看所搭建的模型拓扑结构,以及模型训练过程,方便我们优化模型
    发表于 08-18 07:53

    利用PADS实现3D可视化

    本文给出了利用PADS实现3D可视化的 具体过程,并对PADS和3D技术进行了必要的说明。
    发表于 10-10 16:03 477次下载
    <b class='flag-5'>利用</b>PADS实现<b class='flag-5'>3D</b><b class='flag-5'>可视化</b>

    TensorFlow发表推文正式发布TensorFlow v1.9

    其中有两个案例受到了大家的广泛关注,这个项目是通过 Colab 在 tf.keras 中训练模型,并通过TensorFlow.js 在浏览器中运行;最近在 JS 社区中,对这些相关项目
    的头像 发表于 07-16 10:23 3067次阅读

    如何使用TensorFlow.js构建这一系统

    TensorFlow.js团队一直在进行有趣的基于浏览器的实验,以使人们熟悉机器学习的概念,并鼓励他们将其用作您自己项目的构建块。对于那些不熟悉的人来说,TensorFlow.js是一个开源库,允许
    的头像 发表于 08-19 08:55 3522次阅读

    基于tensorflow.js设计、训练面向web的神经网络模型的经验

    NVIDIA显卡。tensorflow.js在底层使用了WebGL加速,所以在浏览器中训练模型的一个好处是可以利用AMD显卡。另外,在浏览器中训练
    的头像 发表于 10-18 09:43 4055次阅读

    指引入门d3.js的门径,如何基于基本原则创建可视化

    图表不过是内有形状的矩形。d3提供了方法,通过操作图形记号或创建自己的形状来定义你自己的可视化表示。d3使加入视觉交互和声明可视化行为变得容
    的头像 发表于 11-08 09:03 2962次阅读

    TensorFlow.js制作了一个仅用 200 余行代码的项目

    我们先来看一下运行的效果。下图中,上半部分是原始视频,下半部分是使用 TensorFlow.js 对人像进行消除后的视频。可以看到,除了偶尔会在边缘处留有残影之外,整体效果还是很不错的。
    的头像 发表于 05-11 18:08 5561次阅读

    Danfo.js提供高性能、直观易用的数据结构,支持结构数据的操作和处理

    我们的愿景一致,本质上也符合 TensorFlow.js 团队向 Web 引入 ML 的目标。Numpy 和 Pandas 等开源库全面革新了 Python 中数据操作的便利性。因此很多工具都围绕它们构
    的头像 发表于 09-23 18:21 5214次阅读

    如何基于 ES6 的 JavaScript 进行 TensorFlow.js 的开发

    从头开发、训练和部署模型,也可以用来运行已有的 Python 版 TensorFlow 模型,或者基于现有的模型进行继续
    的头像 发表于 10-31 11:16 3067次阅读

    关于Web3D线上数字可视化技术的应用

    Web3D拆装交互爆炸演示三维数字可视化时代融入于我们生活当中。在虚拟现实行业的热门中,Web3D线上数字可视化技术也在是跟随者行业升级和实现科技和信息
    发表于 01-16 10:46 1197次阅读
    关于<b class='flag-5'>Web3D</b>线上数字<b class='flag-5'>可视化</b>技术的应用

    浅谈工业3D可视化建模的特点

    智能3D设备是在工业搭建的3D建模和三维可视化基础上之上构建的一个机遇Web3D的虚拟工业,其运用物理网、云计算等现代信息技术,商迪3D运用
    发表于 04-09 10:23 2045次阅读

    工厂3D可视化模型检测技术管理助力工业绿色发展

    对于工业绿色发展而言,3D可视化模型技术在工厂领域中检测管理发挥着重要的作用。 商迪3D结合3D可视化
    的头像 发表于 04-26 17:19 1735次阅读

    FUXA基于Web过程可视化软件案例

    FUXA——基于Web过程可视化软件
    发表于 04-24 18:32 0次下载