0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

如何将Flax/JAX模型转换为TFLite并在原生Android应用中运行呢

Tensorflowers 来源:TensorFlow 作者:TensorFlow 2022-11-02 10:13 次阅读

在我们之前发布的文章《一个新的 TensorFlow Lite 示例应用:棋盘游戏》中,展示了如何使用 TensorFlow 和 TensorFlow Agents 来训练强化学习 (RL) agent,使其玩一个简单的棋盘游戏“Plane Strike”。我们还将训练后的模型转换为 TensorFlow Lite,然后将其部署到功能完备的 Android 应用中。本文,我们将演示一种全新路径:使用 Flax/JAX 训练相同的强化学习 agent,然后将其部署到我们之前构建的同一款 Android 应用中。

简单回顾一下游戏规则:我们基于强化学习的 agent 需要根据真人玩家的棋盘位置预测击打位置,以便能早于真人玩家完成游戏。如需进一步了解游戏规则,请参阅我们之前发布的文章。

23754442-59d4-11ed-a3b6-dac502259ad0.gif

“Plane Strike”游戏演示

背景:JAX 和 TensorFlow

JAX 是一个与 NumPy 类似的内容库,由 Google Research 部门专为实现高性能计算而开发。JAX 使用 XLA 针对 GPU 和 TPU 优化的程序进行编译。

JAX

https://github.com/google/jax

XLA

https://tensorflow.google.cn/xla

TPU

https://cloud.google.com/tpu

而 Flax 则是在 JAX 基础上构建的一款热门神经网络库。研究人员一直在使用 JAX/Flax 来训练包含数亿万个参数的超大模型(如用于语言理解和生成的 PaLM,或者用于图像生成的 Imagen),以便充分利用现代硬件

Flax

https://github.com/google/flax

PaLM

https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html

Imagen

https://imagen.research.google/

如果您不熟悉 JAX 和 Flax,可以先从 JAX 101 教程和 Flax 入门示例开始。

JAX 101 教程

https://jax.readthedocs.io/en/latest/jax-101/index.html

Flax 入门示例

https://flax.readthedocs.io/en/latest/getting_started.html

2015 年底,TensorFlow 作为 Machine Learning (ML) 内容库问世,现已发展为一个丰富的生态系统,其中包含用于实现 ML 流水线生产化 (TFX)、数据可视化 (TensorBoard),和将 ML 模型部署到边缘设备 (TensorFlow Lite) 的工具,以及在网络浏览器上运行的装置,或能够执行 JavaScript (TensorFlow.js) 的任何装置。

TFX

https://tensorflow.google.cn/tfx

TensorBoard

https://tensorboard.dev/

TensorFlow Lite

https://tensorflow.google.cn/lite

TensorFlow.js

https://tensorflow.google.cn/js

在 JAX 或 Flax 中开发的模型也可以利用这一丰富的生态系统。方法是首先将此类模型转换为 TensorFlow SavedModel 格式,然后使用与它们在 TensorFlow 中原生开发相同的工具。

SavedModel

https://tensorflow.google.cn/guide/saved_model

如果您已经拥有经 JAX 训练的模型并希望立即进行部署,我们整合了一份资源列表供您参考:

视频 “使用 TensorFlow Serving 为 JAX 模型提供服务”,展示了如何使用 TensorFlow Serving 部署 JAX 模型。

https://youtu.be/I4dx7OI9FJQ?t=36

文章《借助 TensorFlow.js 在网络上使用 JAX》,对如何将 JAX 模型转换为 TFJS,并在网络应用中运行进行了详细讲解。

https://blog.tensorflow.org/2022/08/jax-on-web-with-tensorflowjs.html

本篇文章演示了如何将 Flax/JAX 模型转换为 TFLite,并在原生 Android 应用中运行该模型。

总而言之,无论您的部署目标是服务器、网络还是移动设备,我们都会为您提供相应的帮助。

使用 Flax/JAX 实现游戏 agent

将目光转回到棋盘游戏。为了实现强化学习 agent,我们将会利用与之前相同的 OpenAI gym 环境。这次,我们将使用 Flax/JAX 训练相同的策略梯度模型。回想一下,在数学层面上策略梯度的定义是:

OpenAI gym

https://github.com/tensorflow/examples/tree/master/lite/examples/reinforcement_learning/ml/tf_and_jax/gym_planestrike/gym_planestrike/envs

23e88678-59d4-11ed-a3b6-dac502259ad0.png

其中:

T:每段的时步数,各段的时步数可能有所不同

st:时步上的状态 t

at:时步上的所选操作 t 指定状态s

πθ:参数为 θ 的策略

R(*):在指定策略下,收集到的奖励

我们定义了一个 3 层 MLP 作为策略网络,该网络可以预测 agent 的下一个击打位置。

class PolicyGradient(nn.Module):
  """Neural network to predict the next strike position."""


@nn.compact
  def __call__(self, x):
    dtype = jnp.float32
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(
        features=2 * common.BOARD_SIZE**2, name='hidden1', dtype=dtype)(
           x)
    x = nn.relu(x)
    x = nn.Dense(features=common.BOARD_SIZE**2, name='hidden2', dtype=dtype)(x)
    x = nn.relu(x)
    x = nn.Dense(features=common.BOARD_SIZE**2, name='logits', dtype=dtype)(x)
    policy_probabilities = nn.softmax(x)
    return policy_probabilities

在我们训练循环的每次迭代中,我们都会使用神经网络玩一局游戏、收集轨迹信息(游戏棋盘位置、采取的操作和奖励)、对奖励进行折扣,然后使用相应轨迹训练模型。

for i in tqdm(range(iterations)):
   predict_fn = functools.partial(run_inference, params)
   board_log, action_log, result_log = common.play_game(predict_fn)
   rewards = common.compute_rewards(result_log)
   optimizer, params, opt_state = train_step(optimizer, params, opt_state,
                                             board_log, action_log, rewards)

在 train_step() 方法中,我们首先会使用轨迹计算损失,然后使用 jax.grad() 计算梯度,最后,使用 Optax(用于 JAX 的梯度处理和优化库)来更新模型参数。

Optax

https://github.com/deepmind/optax

def compute_loss(logits, labels, rewards):
  one_hot_labels = jax.nn.one_hot(labels, num_classes=common.BOARD_SIZE**2)
  loss = -jnp.mean(
      jnp.sum(one_hot_labels * jnp.log(logits), axis=-1) * jnp.asarray(rewards))
  return loss


def train_step(model_optimizer, params, opt_state, game_board_log,
              predicted_action_log, action_result_log):
"""Run one training step."""

  def loss_fn(model_params):
    logits = run_inference(model_params, game_board_log)
    loss = compute_loss(logits, predicted_action_log, action_result_log)
    return loss

  def compute_grads(params):
    return jax.grad(loss_fn)(params)

  grads = compute_grads(params)
  updates, opt_state = model_optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return model_optimizer, params, opt_state


@jax.jit
def run_inference(model_params, board):
  logits = PolicyGradient().apply({'params': model_params}, board)
  return logits

这就是训练循环。如下图所示,我们可以在 TensorBoard 中观察训练进度;其中,我们使代理指标“game_length”(完成游戏所需的步骤数)来跟踪进度:若 agent 变得更聪明,它便能以更少的步骤完成游戏。

23f8d758-59d4-11ed-a3b6-dac502259ad0.png

将 Flax/JAX 模型转换为

TensorFlow Lite 并与

Android 应用集成

完成模型训练后,我们使用 jax2tf(一款 TensorFlow-JAX 互操作工具),将 JAX 模型转换为 TensorFlow concrete function。最后一步是调用 TensorFlow Lite 转换器来将 concrete function 转换为 TFLite 模型。

jax2tf

https://github.com/google/jax/tree/main/jax/experimental/jax2tf

# Convert to tflite model
 model = PolicyGradient()
 jax_predict_fn = lambda input: model.apply({'params': params}, input)


 tf_predict = tf.function(
     jax2tf.convert(jax_predict_fn, enable_xla=False),
     input_signature=[
         tf.TensorSpec(
             shape=[1, common.BOARD_SIZE, common.BOARD_SIZE],
             dtype=tf.float32,
             name='input')
     ],
     autograph=False,
 )


 converter = tf.lite.TFLiteConverter.from_concrete_functions(
     [tf_predict.get_concrete_function()], tf_predict)


 tflite_model = converter.convert()


 # Save the model
 with open(os.path.join(modeldir, 'planestrike.tflite'), 'wb') as f:
   f.write(tflite_model)

经 JAX 转换的 TFLite 模型与任何经 TensorFlow 训练的 TFLite 模型会有完全一致的行为。您可以使用 Netron 进行可视化:

242392fe-59d4-11ed-a3b6-dac502259ad0.png

使用 Netron 对 Flax/JAX 转换的 TFLite 模型进行可视化

我们可以使用与之前完全一样的 Java 代码来调用模型并获取预测结果。

convertBoardStateToByteBuffer(board);
tflite.run(boardData, outputProbArrays);
float[] probArray = outputProbArrays[0];
int agentStrikePosition = -1;
float maxProb = 0;
for (int i = 0; i < probArray.length; i++) {
  int x = i / Constants.BOARD_SIZE;
  int y = i % Constants.BOARD_SIZE;
  if (board[x][y] == BoardCellStatus.UNTRIED && probArray[i] > maxProb) {
    agentStrikePosition = i;
    maxProb = probArray[i];
  }
}

总结

本文详细介绍了如何使用 Flax/JAX 训练简单的强化学习模型、利用 jax2tf 将其转换为 TensorFlow Lite,以及将转换后的模型集成到 Android 应用。

现在,您已经了解了如何使用 Flax/JAX 构建神经网络模型,以及如何利用强大的 TensorFlow 生态系统,在几乎任何您想要的位置部署模型。我们十分期待看到您使用 JAX 和 TensorFlow 构建出色应用!





审核编辑:刘清

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 神经网络
    +关注

    关注

    42

    文章

    4762

    浏览量

    100517
  • TPU
    TPU
    +关注

    关注

    0

    文章

    138

    浏览量

    20691
  • MLP
    MLP
    +关注

    关注

    0

    文章

    57

    浏览量

    4226

原文标题:使用 JAX 构建强化学习 agent,并借助 TensorFlow Lite 将其部署到 Android 应用中

文章出处:【微信号:tensorflowers,微信公众号:Tensorflowers】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    使用电脑上tensorflow创建的模型转换为tflite格式了,导入后进度条反复出现0-100%变化,为什么?

    使用电脑上tensorflow创建的模型转换为tflite格式了,导入后,进度条反复出现0-100%变化,卡了一个晚上了还没分析好?
    发表于 03-19 06:20

    如何将采样位移转换为采样速度

    我是新手,在Labview编程如何将采样的位移转换为速度?求图解,谢谢
    发表于 04-25 14:56

    如何将秒数转换为时间字符串?

    请问如何将数值型秒数转换为时间字符串?比如3600s转换为01:00:00
    发表于 03-30 13:15

    如何将传统ANN转换为SNN?

    SNN和ANN的区别是什么?如何将传统ANN转换为SNN?
    发表于 09-28 06:15

    如何将触控芯片的IIC接口转换为USB接口

    CH554是什么?CH554如何实现数据转换如何将触控芯片的IIC接口转换为USB接口
    发表于 02-24 07:54

    EIQ onnx模型转换为tf-lite失败怎么解决?

    问题: 而我们需要您帮助我们回答这些问题:a) Dose eIQ(版本 2.7.12)支持 onnx 模型转换为 tflte 格式?(文件见附件)b) 找不到float16 的量化选项,你知道
    发表于 03-31 08:03

    如何在MIMXRT1064评估套件上部署tflite模型

    我有一个婴儿哭声检测 tflite (tensorflow lite) 文件,其中包含模型本身。我如何将模型部署到 MIMXRT1064-evk 以通过 MCUXpresso IDE
    发表于 04-06 06:24

    如何将DS_CNN_S.pb转换为ds_cnn_s.tflite

    用于图像分类(eIQ tensflowlite 库)。从广义上讲,我正在寻找该脚本,您可能已经使用该脚本 DS_CNN_S.pb 转换为 ds_cnn_s.tflite我能够查看两个模型
    发表于 04-19 06:11

    Pytorch模型转换为DeepViewRT模型时出错怎么解决?

    我最终可以在 i.MX 8M Plus 处理器上部署 .rtm 模型。 我遵循了 本指南,我 Pytorch 模型转换为 ONNX 模型
    发表于 06-09 06:42

    如何将Detectron2和Layout-LM模型转换为OpenVINO中间表示(IR)和使用CPU插件进行推断?

    无法确定如何将 Detectron2* 和 Layout-LM* 模型转换为OpenVINO™中间表示 (IR) 和使用 CPU 插件进行推断。
    发表于 08-15 06:23

    数学原理:如何将ADC代码转换为电压(第1篇)

    许多初步了解模数转换器(ADC)的人想知道如何将ADC代码转换为电压。或者,他们的问题是针对特定应用,例如:如何将ADC代码转换回物理量,如
    发表于 04-18 03:30 3934次阅读

    如何将Altera的SDC约束转换为Xilinx XDC约束

    了解如何将Altera的SDC约束转换为Xilinx XDC约束,以及需要更改或修改哪些约束以使Altera的约束适用于Vivado设计软件。
    的头像 发表于 11-27 07:17 5077次阅读

    Android中使用TFLite c++部署

    之前的文章,我们跟大家介绍过如何使用NNAPI来加速TFLite-Android的inference(可参考使用NNAPI加速android-tflite的Mobilenet分类器...
    发表于 02-07 11:57 7次下载
    在<b class='flag-5'>Android</b>中使用<b class='flag-5'>TFLite</b> c++部署

    如何将简单的汽车转换为无线遥控汽车

    电子发烧友网站提供《如何将简单的汽车转换为无线遥控汽车.zip》资料免费下载
    发表于 10-21 14:51 2次下载
    <b class='flag-5'>如何将</b>简单的汽车<b class='flag-5'>转换为</b>无线遥控汽车

    如何将Android代码转换成JS代码运行

    Autojs这个工具,因为它本身是使用的Rhino引擎开发的,因此它可以把Android代码转换成JavaScript语法的代码来运行,Autojs提供了几个相关的方法来辅助
    的头像 发表于 03-03 14:05 2538次阅读