0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

TVM学习(八)pass总结

djelje 来源:djelje 作者:djelje 2022-08-02 09:43 次阅读

什么是pass?

Pass是TVM中基于relay IR进行的优化,目的是去除冗余算子,进行硬件友好的算子转换,最终能够提高硬件运行效率。由tensorflow深度学习框架生成的图机构中,含有很多可以优化的算子,比如expand_dim,len等,其实在编译阶段完全可以优化掉,从而能够减少硬件的计算,以及避免出现硬件不支持的算子。

TVM中在include/tvm/ir/transform.h中对pass进行了抽象,主要包括PassContext,PassInfo,Pass,以及Sequential。其中PassContext包含了pass执行依赖的一些参数,比如优化level,analysis report等。PassInfo是一个用于记录pass信息的类,包括pass的opt-level,名称等。和PassContext的区别是PassContext是pass执行所需要获取的条件。Pass就是执行pass的主体,主要就是pass的函数。比如RemoveUnusedFunctions就是执行pass的一个主体函数,目的就是去除冗余算子。Sequential是一个container,装载所有pass。

一些pass

01. RemoveUnusedFunctions

位于src/relay/backend/vm/removed_unused_funcs.cc中,顾名思义就是去除relay IR中的冗余函数。通过从main函数开始遍历,如果一个函数体没有引用其它函数,而同时又没有被其它函数调用,即从relay图上看是一个孤立算子,那么就从IRModule中删除。

 void VisitExpr_(const FunctionNode* func_node) final {
    auto func = GetRef(func_node);
    if (visiting_.find(func) == visiting_.end()) {
      visiting_.insert(func);
      for (auto param : func_node->params) {
        ExprVisitor::VisitExpr(param);
      }
      ExprVisitor::VisitExpr(func_node-> body);
    }
  }

02. ToBasicBlockNormalForm

函数在文件src/relay/trnaforms/to_basic_block_normal_from.cc中。通过遍历IRModule中的每个function,将每个function转换为基本块形式。转换函数是ToBasicBlockNormalFormAux。这个函数包括两个步骤:一是找到基本块(basic block)的边界,TVM中对边界进行了一步抽象,判断每个expr是否属于同一个scope,如果scope相同那么就可以将这些表达式放在一个基本块中;第二步根据每个表达式所属的scope将表达式归属到一个基本块中。

Expr ToBasicBlockNormalFormAux(const Expr& e) {
  // calculate all the dependency between nodes.
  support::Arena arena;
  DependencyGraph dg = DependencyGraph::Create(&arena, e);
  /* The scope of the whole expr is global.
   * The scope of any subexpr, is the lowest common ancestor of all incoming edge.
   * We also record the set of expressions whose scope is lifted.
   */
  std::pair scopes = CalcScope(dg);
  return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);
}

DependencyGraph是一个表达式相互依赖的图结构,通过遍历图中每个节点,找到每个节点的scope。CalcScope在文件src/relay/transforms/to_a_normal_from.cc中。这个函数中重点关注以下代码:

…
        s = LCA(s, expr_scope.at(iit->value));
…
    if (n->new_scope) {
      auto child_scope = std::make_shared(s);
      expr_scope.insert({n, child_scope});
    } else {
      expr_scope.insert({n, s});
}

LCA是获得当前节点的父节点的scope的LCA(least common ancestor),然后将这个scope作为这个节点的scope。了解基本块原理的都知道,寻找基本块首先要找到首指令的位置,然后一个首指令到下一个首指令之间的指令就属于一个基本块。而首指令就是那些具有条件和无条件跳转的指令。在TVM中通过new_scope来标记这些节点,比如Ifnode,FunctionNode,LetNode在建立dependency图的时候,这些节点就被标记为new_scope。这样就建立了dependency节点到scope节点的对应map。同时scope节点也被建立起树结构。

接下来就是建立Fill类,这个类中包含了dependency图以及scope的信息,通过其函数ToBasicBlockNormalForm实现基本块转换。它的基本逻辑通过VisitExpr函数遍历dependency节点,将具有相同scope的节点压入到同一个let_list中。Let_list文档中是这样解释的:

/*!
 * \file let_list.h
 * \brief LetList record let binding and insert let expression implicitly.
 *  using it, one can treat AST as value instead of expression,
 *  and pass them around freely without fear of AST explosion (or effect duplication).
 *  for example, if one write 'b = a + a; c = b + b; d = c + c', the AST will contain 8 'a'.
 *  if one instead write 'b = ll.Push(a + a); c = ll.Push(b + b); d = ll.Get(c + c);',
 *  the AST will contain 2 'a', as b and c are now variables.

Let_list使得抽象语法树简洁化,不会因为变量的复制导致树的爆炸。具有相同的scope的expr被约束到相同的let_list中,用一个var来表达,这样就将表达式转化为var的形式。一个var也就对应了一个基本块。

03. Legalize

Legalize是实现等价函数的转换。主要代码在src/relay/transforms/legalize.cc中。主函数是:

Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name) {
  auto rewriter = Legalizer(legalize_map_attr_name);
  return PostOrderRewrite(expr, &rewriter);
}

在legalize.cc文件中定义了一个继承了ExprRewriter的类,在这个类中实现了对function的替换。我们追踪一下调用的过程。PostOrderRewrite在文件src/relay/ir/expr_functor.cc中。首先建立一个PostOrderRewriter类,然后访问每个节点。在访问节点过程中调用了ExpandDataFlow函数,看一下这个函数的描述:

*
 * ExpandDataflow manually manages a stack and performs DFS to determine the processing
 * order of nodes in an input graph.
 *
 * If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
 * need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
 * and continues iteratively to process the top of the stack. When it finds a node that doesn't
 * match the dataflow types, or a node who's inputs have all been processed, it visits the current
 * leaf via fvisit_leaf.
 *
 * This function should be used internally to other classes to implement mixed-mode traversals. The
 * expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
 * hits a non-dataflow node.
 *
 * fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
 */

主要目的是有区别的去处理graph中的节点,如果fcheck_visited已经确定该节点处理过或者不需要处理,就跳过,通过fvisit_leaf继续访问下一个节点。而在VisitLeaf函数中就调用了legalizer类中的rewrite_函数实现了legalize功能。在Rewrite_中,通过映射表legalize_map_attr_name实现函数的等价转换。

04. SimplifyInference

实现对batch normalization, layer normalization, instance normalization, group normalization, L2 normalization算子的分解,这样做的目的是可以在之后的优化中,将这些算子融合到其它算子上,减少计算量。代码在src/relay/transforms/simplify_inference.cc中。文件中定义了一个InferenceSimplifier类来处理这个问题。看一下这几个normalization的公式:

1 BN:

pYYBAGGYIDKAMYXkAALAFPdMTWI678.png

2 LN:获得均值和方差是基于同一层不同神经元的数据。归一化公式相同。

3 GN: 将每个输入样本沿着通道进行分组,在每个组内进行归一化。

4 IN:对每个通道的数据进行归一化。

来看一下bacth normalization的处理代码:

Expr BatchNormToInferUnpack(const Attrs attrs, Expr data, Expr gamma, Expr beta, Expr moving_mean,
                            Expr moving_var, Type tdata) {
  auto ttype = tdata.as();
  CHECK(ttype);
  const auto param = attrs.as< BatchNormAttrs>();
  Expr epsilon = MakeConstantScalar(ttype->dtype, static_cast(param->epsilon));
  Expr var_add_eps = Add(moving_var, epsilon);
  Expr sqrt_var = Sqrt(var_add_eps);
  Expr scale = Divide(MakeConstantScalar(ttype->dtype, 1.0f), sqrt_var);


  if (param->scale) {
    scale = Multiply(scale, gamma);
  }
  Expr neg_mean = Negative(moving_mean);
  Expr shift = Multiply(neg_mean, scale);
  if (param->center) {
    shift = Add(shift, beta);
  }


  auto ndim = ttype->shape.size();
  int axis = (param->axis <  0) ? param->axis + ndim : param->axis;
  scale = ExpandBiasToMatchAxis(scale, ndim, {axis});
  shift = ExpandBiasToMatchAxis(shift, ndim, {axis});


  Expr out = Multiply(data, scale);
  out = Add(out, shift);
  return out;
}

可以看到就是将batch norm算子分解成最基本的加减乘除算子。

05. EliminateCommonSubexpr

顾名思义,这个pass的目的是消除公共子表达式。公共子表达式类似这种:

a=b+c

d=b+c

两个表达式具有相同的op,同时又有相同的args,而且args的顺序也一样。那么就可以用一个表达式替换。

这个pass的实现在文件src/relay/transforms/eliminate_common_subexpr.cc中。TVM定义了类CommonSubexprEliminator来处理。重载函数Rewrite_实现了对expr的遍历和重写操作。

 Expr Rewrite_(const CallNode* call, const Expr& post) final {
…
    if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef< Op>(op), false)) {
      return new_expr;
    }
    if (fskip_ != nullptr && fskip_(new_expr)) {
      return new_expr;
    }


    auto it = expr_map_.find(new_call->op);
    if (it != expr_map_.end()) {
      for (const Expr& candidate_expr : it->second) {
        if (const CallNode* candidate = candidate_expr.as< CallNode>()) {
          bool is_equivalent = true;
          if (!attrs_equal(new_call->attrs, candidate->attrs)) {
            continue;
          }
          for (size_t i = 0; i <  new_call->args.size(); i++) {
            if (!new_call->args[i].same_as(candidate->args[i]) &&
                !IsEqualScalar(new_call->args[i], candidate->args[i])) {
              is_equivalent = false;
              break;
            }
          }
          if (!is_equivalent) continue;
          return GetRef(candidate);
        }
      }
    }
    expr_map_[new_call->op].push_back(new_expr);
    return new_expr;
  }

使用一个expr_map_映射记录已经遍历过的具有相同op的expr,之后每次遇到相同的op都会对已经记录的expr进行匹配,匹配包括attrs以及args,如果二者都一样的话,证明就是公共子表达式。

没有看过的pass

以上是实现相对简单的pass,TVM中还实现了其它很多pass,就没有一一去读代码了。以后看需要再去读吧。现在做一些罗列:

1 SimplifyExpr

简化一些表达式,具体如何进行简化需要读代码了。

2 CombineParallelConv2D

合并多分支并行的conv2d运算,理解是对多个batch的conv2d进行合并。

3 CombineParalleleDense

将多个batch的dense操作合并为一个batch_matmul操作。

4 CombineParallelBatchMatmul

对多个并行的batch_mamul再进行合并。

这几个combine操作可能是针对GPU器件的一个多数据并行性的优化。

5 FoldConstant

典型的一个常量合并优化。

6 FoldScaleAxis

包含了ForwardFoldScaleAxis和backwardFoldScaleAxis,主要是将scale参数合并到conv/dense操作的权重参数中。

7 CanonicalizeCast

官方解释是: Canonicalize cast expressions to make operator fusion more efficient。理解是对一些cast操作规范化,就是让复杂的cast操作可以更简洁。

8 CanonicalizeOps

规范化一些算子,比如bias_add能够被表示为expand_dims和broadcast_add操作。

审核编辑 黄昊宇

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 优化
    +关注

    关注

    0

    文章

    220

    浏览量

    23905
  • TVM
    TVM
    +关注

    关注

    0

    文章

    19

    浏览量

    3662
收藏 人收藏

    评论

    相关推荐

    TVM主要的编译过程解析

    `  TVM主要的编译过程如下图:    Import:将tensorflow,onnx,pytorch等构建的深度学习模型导入,转化成TVM的中间层表示IR。  Lower:将高层IR表示转化成
    发表于 01-07 16:59

    TVM整体结构,TVM代码的基本构成

    图:    Frontend:这个就是将来自不同深度学习框架中的神经网络转化成TVM自己的IR表示。神经网络模型的输入是protoBuf文件,比如在tensorflow中就是pbtxt文件,这个文件中
    发表于 01-07 17:21

    TVM中将计算算符有哪几种

    TVM中将计算算符分成四种
    发表于 01-26 06:34

    TVM的编译流程

    TVM主要的编译过程
    发表于 02-23 07:43

    SOPC Builder/Nios 学习经验总结

    SOPC Builder/Nios 学习经验总结
    发表于 07-22 15:32 0次下载
    SOPC Builder/Nios <b class='flag-5'>学习</b>经验<b class='flag-5'>总结</b>

    FPGA学习总结[经典推荐]

    单片机(Microcontrollers)学习,FPGA学习总结[经典推荐],感兴趣的小伙伴可以瞧一瞧。
    发表于 11-03 15:15 155次下载

    ARM寄存器学习总结

    ARM寄存器学习总结
    发表于 01-04 15:10 0次下载

    TVM用于移动端常见的ARM GPU,提高移动设备对深度学习的支持能力

    的压力。 TVM是一个端到端的IR堆栈,它可以解决学习过程中的资源分配问题,从而轻松实现硬件优化。在这篇文章中,我们将展示如何用TVM/NNVM为ARM Mali GPU生成高效kernel
    的头像 发表于 01-18 13:38 1.1w次阅读

    什么是波场虚拟机TVM

    TVM与现有的开发生态系统无缝连接,并支持 DPoS。 TVM最初与 EVM 环境兼容,因此开发人员可以使用Solidity和其他语言在 Remix 环境中开发,调试和编译智能合约,而不是学习
    发表于 05-15 09:46 3256次阅读
    什么是波场虚拟机<b class='flag-5'>TVM</b>

    Linux的基础学习笔记资料总结

    本文档的主要内容详细介绍的是Linux的基础学习笔记资料总结包括了:一、 常用命令,二、 磁盘管理,三、 用户管理,四、 文件权限,五、 目录结构,六、 软件安装,七、 时间管理,、 启动引导,九
    发表于 11-13 08:00 4次下载

    TVM的编译流程是什么

    TVM主要的编译过程如下图:Import:将tensorflow,onnx,pytorch等构建的深度学习模型导入,转化成TVM的中间层表示IR。Lower:将高层IR表示转化成低阶TIR表示。Codegen:内存分配和硬件可执
    的头像 发表于 02-08 14:51 1696次阅读
    <b class='flag-5'>TVM</b>的编译流程是什么

    TVM学习(三)编译流程

    TVM主要的编译过程如下图:Import:将tensorflow,onnx,pytorch等构建的深度学习模型导入,转化成TVM的中间层表示IR。Lower:将高层IR表示转化成低阶TIR表示。Codegen:内存分配和硬件可执
    发表于 01-26 09:23 13次下载
    <b class='flag-5'>TVM</b><b class='flag-5'>学习</b>(三)编译流程

    TVM学习(二):算符融合

    算符融合将多个计算单元揉进一个计算核中进行,减少了中间数据的搬移,节省了计算时间。TVM中将计算算符分成四种: 1 injective。一一映射函数,比如加法,点乘等。 2 reduction。输入
    发表于 02-19 06:50 10次下载
    <b class='flag-5'>TVM</b><b class='flag-5'>学习</b>(二):算符融合

    FDTD学习总结.pdf

    FDTD学习总结.pdf
    发表于 01-17 11:28 0次下载

    使用TVM在android中进行Mobilenet SSD部署

    所谓TVM,按照正式说法:就是一种将深度学习工作负载部署到硬件的端到端IR(中间表示)堆栈。换一种说法,可以表述为一种把深度学习模型...
    发表于 02-07 12:07 0次下载
    使用<b class='flag-5'>TVM</b>在android中进行Mobilenet SSD部署