发布人:TensorFlow 团队的 Mathieu Guillame-Bert 和 Josh Gordon
随机森林和梯度提升树这类的决策森林模型通常是处理表格数据最有效的可用工具。与神经网络相比,决策森林具有更多优势,如配置过程更轻松、训练速度更快等。使用树可大幅减少准备数据集所需的代码量,因为这些树本身就可以处理数字、分类和缺失的特征。此外,这些树通常还可提供开箱即用的良好结果,并具有可解释的属性。
尽管我们通常将 TensorFlow 视为训练神经网络的内容库,但 Google 的一个常见用例是使用 TensorFlow 创建决策森林。

对数据开展分类的决策树动画
如果您曾使用 2019 年推出的 tf.estimator.BoostedTrees 创建基于树的模型,您可参考本文所提供的指南进行迁移。虽然 Estimator API 基本可以应对在生产环境中使用模型的复杂性,包括分布式训练和序列化,但是我们不建议您将其用于新代码。
如果您要开始一个新项目,我们建议您使用 TensorFlow 决策森林 (TF-DF)。该内容库可为训练、服务和解读决策森林模型提供最先进的算法,相较于先前的方法更具优势,特别是在质量、速度和易用性方面表现尤为出色。
首先,让我们来比较一下使用 Estimator API 和 TF-DF 创建提升树模型的等效示例。
以下是使用 tf.estimator.BoostedTrees 训练梯度提升树模型的旧方法(不再推荐使用)
import tensorflow as tf
# Dataset generators
def make_dataset_fn(dataset_path):
def make_dataset():
data = ... # read dataset
return tf.data.Dataset.from_tensor_slices(...data...).repeat(10).batch(64)
return make_dataset
# List the possible values for the feature "f_2".
f_2_dictionary = ["NA", "red", "blue", "green"]
# The feature columns define the input features of the model.
feature_columns = [
tf.feature_column.numeric_column("f_1"),
tf.feature_column.indicator_column(
tf.feature_column.categorical_column_with_vocabulary_list("f_2",
f_2_dictionary,
# A special value "missing" is used to represent missing values.
default_value=0)
),
]
# Configure the estimator
estimator = boosted_trees.BoostedTreesClassifier(
n_trees=1000,
feature_columns=feature_columns,
n_classes=3,
# Rule of thumb proposed in the BoostedTreesClassifier documentation.
n_batches_per_layer=max(2, int(len(train_df) / 2 / FLAGS.batch_size)),
)
# Stop the training is the validation loss stop decreasing.
early_stopping_hook = early_stopping.stop_if_no_decrease_hook(
estimator,
metric_name="loss",
max_steps_without_decrease=100,
min_steps=50)
tf.estimator.train_and_evaluate(
estimator,
train_spec=tf.estimator.TrainSpec(
make_dataset_fn(train_path),
hooks=[
# Early stopping needs a CheckpointSaverHook.
tf.train.CheckpointSaverHook(
checkpoint_dir=input_config.raw.temp_dir, save_steps=500),
early_stopping_hook,
]),
eval_spec=tf.estimator.EvalSpec(make_dataset_fn(valid_path)))
使用 TensorFlow 决策森林训练相同的模型
import tensorflow_decision_forests as tfdf
# Load the datasets
# This code is similar to the estimator.
def make_dataset(dataset_path):
data = ... # read dataset
return tf.data.Dataset.from_tensor_slices(...data...).batch(64)
train_dataset = make_dataset(train_path)
valid_dataset = make_dataset(valid_path)
# List the input features of the model.
features = [
tfdf.keras.FeatureUsage("f_1", keras.FeatureSemantic.NUMERICAL),
tfdf.keras.FeatureUsage("f_2", keras.FeatureSemantic.CATEGORICAL),
]
model = tfdf.keras.GradientBoostedTreesModel(
task = tfdf.keras.Task.CLASSIFICATION,
num_trees=1000,
features=features,
exclude_non_specified_features=True)
model.fit(train_dataset, valid_dataset)
# Export the model to a SavedModel.
model.save("project/model")
附注
-
虽然在此示例中没有明确说明,但 TensorFlow 决策森林可自动启用和配置早停。
-
可自动构建和优化“f_2”特征字典(例如,将稀有值合并到一个未登录词项目中)。
-
可从数据集中自动确定类别数(本例中为 3 个)。
-
批次大小(本例中为 64)对模型训练没有影响。以较大值为宜,因为这可以增加读取数据集的效率。
TF-DF 的亮点就在于简单易用,我们还可进一步简化和完善上述示例,如下所示。
如何训练 TensorFlow 决策森林(推荐解决方案)
import tensorflow_decision_forests as tfdf
import pandas as pd
# Pandas dataset can be used easily with pd_dataframe_to_tf_dataset.
train_df = pd.read_csv("project/train.csv")
# Convert the Pandas dataframe into a TensorFlow dataset.
train_ds = tfdf.keras.pd_dataframe_to_tf_dataset(train_df, label="my_label")
model = tfdf.keras.GradientBoostedTreeModel(num_trees=1000)
model.fit(train_dataset)
附注
-
我们未指定特征的语义(例如数字或分类)。在这种情况下,系统将自动推断语义。
-
我们也没有列出要使用的输入特征。在这种情况下,系统将使用所有列(标签除外)。可在训练日志中查看输入特征的列表和语义,或通过模型检查器 API 查看。
-
我们没有指定任何验证数据集。每个算法都可以从训练样本中提取一个验证数据集作为算法的最佳选择。例如,默认情况下,如果未提供验证数据集,则 GradientBoostedTreeModel 将使用 10% 的训练数据进行验证。
下面我们将介绍 Estimator API 和 TF-DF 的一些区别。
Estimator API 和 TF-DF 的区别
算法类型
TF-DF 是决策森林算法的集合,包括(但不限于)Estimator API 提供的梯度提升树。请注意,TF-DF 还支持随机森林(非常适用于干扰数据集)和 CART 实现(非常适用于解读模型)。
此外,对于每个算法,TF-DF 都包含许多在文献资料中发现并经过实验验证的变体 [1, 2, 3]。
精确与近似分块的对比
TF1 GBT Estimator 是一种近似的树学习算法。非正式情况下,Estimator 通过仅考虑样本的随机子集和每个步骤条件的随机子集来构建树。
默认情况下,TF-DF 是一种精确的树训练算法。非正式情况下,TF-DF 会考虑所有训练样本和每个步骤的所有可能分块。这是一种更常见且通常表现更佳的解决方案。
虽然对于较大的数据集(具有百亿数量级以上的“样本和特征”数组)而言,有时 Estimator 的速度更快,但其近似值通常不太准确(因为需要种植更多树才能达到相同的质量)。而对于小型数据集(所含的“样本和特征”数组数目不足一亿)而言,使用 Estimator 实现近似训练形式的速度甚至可能比精确训练更慢。
TF-DF 还支持不同类型的“近似”树训练。我们建议您使用精确训练法,并选择使用大型数据集测试近似训练。
推理
Estimator 使用自上而下的树路由算法运行模型推理。TF-DF 使用 QuickScorer 算法的扩展程序。
虽然两种算法返回的结果完全相同,但自上而下的算法效率较低,因为这种算法的计算量会超出分支预测并导致缓存未命中。对于同一模型,TF-DF 的推理速度通常可提升 10 倍。
TF-DF 可为延迟关键应用程序提供 C++ API。其推理时间约为每核心每样本 1 微秒。与 TF SavedModel 推理相比,这通常可将速度提升 50 至 1000 倍(对小型批次的效果更佳)。
多头模型
Estimator 支持多头模型(即输出多种预测的模型)。目前,TF-DF 无法直接支持多头模型,但是借助 Keras Functional API,TF-DF 可以将多个并行训练的 TF-DF 模型组成一个多头模型。
了解详情
您可以访问此网址,详细了解 TensorFlow 决策森林。
如果您是首次接触该内容库,我们建议您从初学者示例开始。经验丰富的 TensorFlow 用户可以访问此指南,详细了解有关在 TensorFlow 中使用决策森林和神经网络的区别要点,包括如何配置训练流水线和关于数据集 I/O 的提示。
您还可以仔细阅读从Estimator 迁移到 Keras API,了解如何从 Estimator 迁移到 Keras。
原文标题:如何从提升树 Estimator 迁移到 TensorFlow 决策森林
文章出处:【微信公众号:谷歌开发者】欢迎添加关注!文章转载请注明出处。
-
Google
+关注
关注
5文章
1784浏览量
58580 -
模型
+关注
关注
1文章
3474浏览量
49891 -
tensorflow
+关注
关注
13文章
330浏览量
60978
原文标题:如何从提升树 Estimator 迁移到 TensorFlow 决策森林
文章出处:【微信号:Google_Developers,微信公众号:谷歌开发者】欢迎添加关注!文章转载请注明出处。
发布评论请先 登录
将TensorFlow模型转换为中间表示 (IR) 时遇到不一致的形状错误怎么解决?
将YOLOv4模型转换为IR的说明,无法将模型转换为TensorFlow2格式怎么解决?
为什么无法使用OpenVINO™模型优化器转换TensorFlow 2.4模型?
为什么无法将自定义EfficientDet模型从TensorFlow 2转换为中间表示(IR)?
宇树科技在物联网方面
快速部署Tensorflow和TFLITE模型在Jacinto7 Soc

TensorFlow是什么?TensorFlow怎么用?
使用TensorFlow进行神经网络模型更新
tensorflow简单的模型训练
keras模型转tensorflow session
如何使用Tensorflow保存或加载模型

搭建树莓派网络监控系统:顶级工具与技术终极指南!
树莓派网络监控系统是一种经济高效且功能多样的解决方案,可用于监控网络性能、流量及整体运行状况。借助树莓派,我们可以搭建一个网络监控系统,实时洞察网络活动,从而帮助识别问题、优化性能并确保网络安全。安装树莓派网络监控系统有诸多益处。树莓派具备以太网接口,还内置了Wi-Fi功能,拥有足够的计算能力和内存,能够在Linux或Windows系统上运行。因此,那些为L

STM32驱动SD NAND(贴片式SD卡)全测试:GSR手环生物数据存储的擦写寿命与速度实测
在智能皮电手环及数据存储技术不断迭代的当下,主控 MCU STM32H750 与存储 SD NAND MKDV4GIL-AST 的强强联合,正引领行业进入全新发展阶段。二者凭借低功耗、高速读写与卓越稳定性的深度融合,以及高容量低成本的突出优势,成为大规模生产场景下极具竞争力的数据存储解决方案。

芯对话 | CBM16AD125Q这款ADC如何让我的性能翻倍?
综述在当今数字化时代,模数转换器(ADC)作为连接模拟世界与数字系统的关键桥梁,其技术发展对众多行业有着深远影响。从通信领域追求更高的数据传输速率与质量,到医疗影像领域渴望更精准的疾病诊断,再到工业控制领域需要适应复杂恶劣环境的稳定信号处理,ADC的性能提升成为推动这些行业进步的重要因素。行业现状分析在通信行业,5G乃至未来6G的发展,对基站信号处理提出了极

史上最全面解析:开关电源各功能电路
01开关电源的电路组成开关电源的主要电路是由输入电磁干扰滤波器(EMI)、整流滤波电路、功率变换电路、PWM控制器电路、输出整流滤波电路组成。辅助电路有输入过欠压保护电路、输出过欠压保护电路、输出过流保护电路、输出短路保护电路等。开关电源的电路组成方框图如下:02输入电路的原理及常见电路1AC输入整流滤波电路原理①防雷电路:当有雷击,产生高压经电网导入电源时

有几种电平转换电路,适用于不同的场景
一.起因一般在消费电路的元器件之间,不同的器件IO的电压是不同的,常规的有5V,3.3V,1.8V等。当器件的IO电压一样的时候,比如都是5V,都是3.3V,那么其之间可以直接通讯,比如拉中断,I2Cdata/clk脚双方直接通讯等。当器件的IO电压不一样的时候,就需要进行电平转换,不然无法实现高低电平的变化。二.电平转换电路常见的有几种电平转换电路,适用于

瑞萨RA8系列教程 | 基于 RASC 生成 Keil 工程
对于不习惯用 e2 studio 进行开发的同学,可以借助 RASC 生成 Keil 工程,然后在 Keil 环境下愉快的完成开发任务。

共赴之约 | 第二十七届中国北京国际科技产业博览会圆满落幕
作为第二十七届北京科博会的参展方,芯佰微有幸与800余家全球科技同仁共赴「科技引领创享未来」之约!文章来源:北京贸促5月11日下午,第二十七届中国北京国际科技产业博览会圆满落幕。本届北京科博会主题为“科技引领创享未来”,由北京市人民政府主办,北京市贸促会,北京市科委、中关村管委会,北京市经济和信息化局,北京市知识产权局和北辰集团共同承办。5万平方米的展览云集

道生物联与巍泰技术联合发布 RTK 无线定位系统:TurMass™ 技术与厘米级高精度定位的深度融合
道生物联与巍泰技术联合推出全新一代 RTK 无线定位系统——WTS-100(V3.0 RTK)。该系统以巍泰技术自主研发的 RTK(实时动态载波相位差分)高精度定位技术为核心,深度融合道生物联国产新兴窄带高并发 TurMass™ 无线通信技术,为室外大规模定位场景提供厘米级高精度、广覆盖、高并发、低功耗、低成本的一站式解决方案,助力行业智能化升级。

智能家居中的清凉“智”选,310V无刷吊扇驱动方案--其利天下
炎炎夏日,如何营造出清凉、舒适且节能的室内环境成为了大众关注的焦点。吊扇作为一种经典的家用电器,以其大风量、长寿命、低能耗等优势,依然是众多家庭的首选。而随着智能控制技术与无刷电机技术的不断进步,吊扇正朝着智能化、高效化、低噪化的方向发展。那么接下来小编将结合目前市面上的指标,详细为大家讲解其利天下有限公司推出的无刷吊扇驱动方案。▲其利天下无刷吊扇驱动方案一

电源入口处防反接电路-汽车电子硬件电路设计
一、为什么要设计防反接电路电源入口处接线及线束制作一般人为操作,有正极和负极接反的可能性,可能会损坏电源和负载电路;汽车电子产品电性能测试标准ISO16750-2的4.7节包含了电压极性反接测试,汽车电子产品须通过该项测试。二、防反接电路设计1.基础版:二极管串联二极管是最简单的防反接电路,因为电源有电源路径(即正极)和返回路径(即负极,GND),那么用二极

半导体芯片需要做哪些测试
首先我们需要了解芯片制造环节做⼀款芯片最基本的环节是设计->流片->封装->测试,芯片成本构成⼀般为人力成本20%,流片40%,封装35%,测试5%(对于先进工艺,流片成本可能超过60%)。测试其实是芯片各个环节中最“便宜”的一步,在这个每家公司都喊着“CostDown”的激烈市场中,人力成本逐年攀升,晶圆厂和封装厂都在乙方市场中“叱咤风云”,唯独只有测试显

解决方案 | 芯佰微赋能示波器:高速ADC、USB控制器和RS232芯片——高性能示波器的秘密武器!
示波器解决方案总述:示波器是电子技术领域中不可或缺的精密测量仪器,通过直观的波形显示,将电信号随时间的变化转化为可视化图形,使复杂的电子现象变得清晰易懂。无论是在科研探索、工业检测还是通信领域,示波器都发挥着不可替代的作用,帮助工程师和技术人员深入剖析电信号的细节,精准定位问题所在,为创新与发展提供坚实的技术支撑。一、技术瓶颈亟待突破性能指标受限:受模拟前端

硬件设计基础----运算放大器
1什么是运算放大器运算放大器(运放)用于调节和放大模拟信号,运放是一个内含多级放大电路的集成器件,如图所示:左图为同相位,Vn端接地或稳定的电平,Vp端电平上升,则输出端Vo电平上升,Vp端电平下降,则输出端Vo电平下降;右图为反相位,Vp端接地或稳定的电平,Vn端电平上升,则输出端Vo电平下降,Vn端电平下降,则输出端Vo电平上升2运算放大器的性质理想运算

ElfBoard技术贴|如何调整eMMC存储分区
ELF 2开发板基于瑞芯微RK3588高性能处理器设计,拥有四核ARM Cortex-A76与四核ARM Cortex-A55的CPU架构,主频高达2.4GHz,内置6TOPS算力的NPU,这一设计让它能够轻松驾驭多种深度学习框架,高效处理各类复杂的AI任务。

米尔基于MYD-YG2LX系统启动时间优化应用笔记
1.概述MYD-YG2LX采用瑞萨RZ/G2L作为核心处理器,该处理器搭载双核Cortex-A55@1.2GHz+Cortex-M33@200MHz处理器,其内部集成高性能3D加速引擎Mail-G31GPU(500MHz)和视频处理单元(支持H.264硬件编解码),16位的DDR4-1600/DDR3L-1333内存控制器、千兆以太网控制器、USB、CAN、
评论