您好,欢迎来电子发烧友网! ,新用户?[免费注册]

您的位置:电子发烧友网>源码下载>数值算法/人工智能>

TensorFlow模型详解与应用

大小:0.2 MB 人气: 2017-09-28 需要积分:1

  DNNLinearCombinedClassifier 类继承于类 Estimator,Estimator 类继承于类 BaseEstimator。BaseEstimator 是一个抽象类,定义了通用的模型训练以及评测的函数接口 (train_model, evaluate_model, infer_model),Estimator 类中用一个统一函数 call_model_fn 来实现 train_model, evaluate_model, infer_model。

  TensorFlow模型详解与应用

  图 7 estimator 的类关系图

  为了更好了解整个过程,我们看看内部函数的调用过程(代码可以参见 estimator/estimator.py):

  TensorFlow模型详解与应用

  图 8 Estmiator 类的函数调用图

  模型训练通过调用 BaseEstimator 的 fit() 接口开始,其调用栈是:fit -》 _train_model -》 _get_train_ops -》_call_model_fn(ModelKeys.TRAIN) -》 _model_fn,最终_model_fn() 产生模型并通过 export 函数将模型输出到 model_dir 对应目录中。

  我们把训练模型的调用过程在代码级别展开,标出关键的几个函数和数据结构,省略不关键的代码,希望能让读者看到训练模型的大致过程:

  TensorFlow模型详解与应用

  图 9 模型训练的调用栈

  评测(evaluate)和预测(predict)的过程与训练(train)大致相同,读者可以通过源代码文件找到对应函数了解。可以看出,整个函数调用栈中最关键的 2 个函数是: input_fn 和 model_fn。input_fn 从输入数据中生成 features 和 labels,features 是一个 Tensor 或者是一个从特征名到 Tensor 的字典,如果 features 是一个 Tensor,程序会给这个 Tensor 一个空字符串的键值,转换成特征名到 Tensor 的字典。labels 是样本的 label 构成的 tensor。input_fn 由应用程序调用者提供实现,返回(features, labels)二元组,要求 tf.get_shape(features)[0] == tf.get_shape(labels)[0],也就是两个 tensor 的行数目得保持一致。model_fn 定义训练和评测模型的具体逻辑,如模型训练产生的误差 (model_fn_ops.loss) 以及训练算子(model_fn_ops.train_op)通过封装在 EstmiatorSpec 的对象中由 training 的 Session 进行调用。每个具体模型需要实现的是自定义的 model_fn。

  DNNLinearCombinedClassifier 是如何实现自己的 model_fn 的呢?本文开头我们给出了它的初始化函数原型,进入初始化函数的实现中我们定位到代码行 model_fn=_dnn_linear_combined_model_fn。

  这个就是 DNNLinearCombinedClassifier 的 model_fn。这个函数的定义如下:

  def_dnn_linear_combined_model_fn(features, labels, mode, params, config= None)

  features 和 labels 大家都已经知道,mode 指定 model_fn 的操作模式,目前支持 3 个值:训练模型 (model_fn.ModeKeys.TRAIN),对模型进行评测 (model_fn.ModeKeys.EVAL),根据输入特征进行预测 (model_fn.ModeKeys.PREDICT),mode 的定义可参见文件 estimator/model_fn.py。params 和 config 参数分别定义模型训练的参数以及模型运行的配置。

非常好我支持^.^

(2) 40%

不好我反对

(3) 60%

TensorFlow模型详解与应用下载

相关电子资料下载

      发表评论

      用户评论
      评价:好评中评差评

      发表评论,获取积分! 请遵守相关规定!