0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

TVM学习之从relay到TOPI

ytrwv 来源:ytrwv 作者:ytrwv 2022-08-02 10:16 次阅读

Lower操作完成从高级算子(relay)到低级算子(TOPI)的转化。Lower开始于以下代码(src/relay/backend/graph_runtime_codegen.cc):

 LoweredOutput Codegen(relay::Function func) {
    auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
    storage_device_map_ = (*pf)(func);
    // First we convert all the parameters into input nodes.
    for (auto param : func->params) {
      auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs());
      var_map_[param.get()] = AddNode(node_ptr, param);
    }
    heads_ = VisitExpr(func->body);
    std::ostringstream os;
    dmlc::JSONWriter writer(&os);
    GetJSON(&writer);
    LoweredOutput ret;
    ret.graph_json = os.str();
    ret.params = params_;


    for (auto& kv : lowered_funcs_) {
      if (ret.lowered_funcs.count(kv.first) == 0) {
        ret.lowered_funcs.Set(kv.first, IRModule());
      }
      auto& mod = ret.lowered_funcs[kv.first];
      mod->Update(kv.second);
      ret.lowered_funcs.Set(kv.first, mod);
    }
    ret.external_mods = compile_engine_->LowerExternalFunctions();
    return ret;
  }

在完成内存申请优化之后,VisitExpr对图进行遍历并lower每个relay算子。我们来看CallNode节点的处理。主要代码如下:

auto pf0 = GetPackedFunc("relay.backend._make_CCacheKey");
    auto pf1 = GetPackedFunc("relay.backend._CompileEngineLower");
    Target target;
    // Handle external function
    if (func->GetAttr(attr::kCompiler).defined()) {
      target = tvm::target::ext_dev();
      CCacheKey key = (*pf0)(func, target);
      CachedFunc ext_func = (*pf1)(compile_engine_, key);
这一步是当存在外部compiler的时候,使用外部compiler进行lower。CCacheKey将function和target打包到一起,可能是方便后边compiler的调用。而lower函数会调用src/relay/backend/compile_engine.cc中CompileEngineImpl类中的LowerInternal函数,在这个函数中实现了外部编译器lower和内部lower的代码,如果是有外部compiler参与,其将function,target等打包成CCacheValue返回,等待后边外部编译器统一处理。
如果没有外部编译器,那么TVM将对relay算子转换到TOPI库中算子。
CachedFunc lowered_func = (*pf1)(compile_engine_, key);
    if (!lowered_funcs_.count(target->str())) {
      lowered_funcs_[target->str()] = IRModule();
    }
    lowered_funcs_[target->str()]->Update(lowered_func->funcs);
return GraphAddCallNode(op, _GetUniqueName(lowered_func->func_name), lowered_func->func_name);

同样会调用LowerInternal函数,首先建立schedule:

CachedFunc CreateSchedule(const Function& source_func, const Target& target) {
    return ScheduleGetter(target).Create(source_func);
  }

在Create函数中,首先将inputs都转换成te的算子表示:

for (Var param : prim_func-> params) {
      Array  inputs;
      if (const auto* ttype = param->checked_type().as< TensorTypeNode>()) {
        tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype-> shape), ttype->dtype);
        cache_node-> inputs.push_back(tensor);
        inputs.push_back(tensor);
      } else {
        // flatten tuple of tensor type.
        const auto* tuple_type = param-> type_as ();
        for (Type field : tuple_type-> fields) {
          const auto* ttype = field.as< TensorTypeNode> ();
          // TODO(@icemelon): Allow recursive tuple
          CHECK(ttype != nullptr);
          tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype-> shape), ttype-> dtype);
          cache_node-> inputs.push_back(tensor);
          inputs.push_back(tensor);
        }
      }
      memo_[param] = inputs;
}

然后遍历其它node来实现lower操作。

我们还是来看CallNode的访问。

Array VisitExpr_(const CallNode* call_node) final {
    static auto fpattern = Op::GetAttrMap("TOpPattern");
    static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
    CHECK(flower_call) << "relay.backend.lower_call is not registered.";


    Array inputs;
    int count_tuple = 0;
    for (Expr arg : call_node->args) {
      if (arg->checked_type().as()) {
        ++count_tuple;
      }
      for (te::Tensor tensor : VisitExpr(arg)) {
        inputs.push_back(tensor);
      }
    }
    if (count_tuple) {
      CHECK_EQ(call_node-> args.size(), 1U) << "Only allow function with a single tuple input";
    }


    CHECK(call_node->op.as>OpNode> ()) >> "Primitive function only allows call into primitive ops";
    Op op = Downcast>Op>(call_node-> op);


    Array>te::Tensor> outputs;
    OpImplementation impl;
    // Skip fcompute for device copy operators as it is not registered.
    if (op == device_copy_op_) {
      const auto* copy_input = inputs[0].operator->();
      outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0));
    } else {
      LoweredOutput lowered_out = (*flower_call)(GetRef>Call>(call_node), inputs, target_);
      outputs = lowered_out->outputs;

这里lower操作会去调用python中注册的lower_call函数,这个函数位于python/tvm/relay/backend/compile_engine.py中。在这个函数中最主要的是select_implementation。

Select_implementation是去选择relay算子的一个TOPI层级的实现方式。同一个relay算子在不同target上有不同实现方式,具体采用哪种方式要依据target的属性。在select_implementation中首先通过gat_valid_implementation获得所有已经注册的实现方式。

fstrategy = op.get_attr("FTVMStrategy")
    assert fstrategy is not None, "%s doesn't have FTVMStrategy registered" % op.name
    with target:
        strategy = fstrategy(attrs, inputs, out_type, target)
    analyzer = tvm.arith.Analyzer()
    ret = []
    for spec in strategy.specializations:
        if spec.condition:
            # check if all the clauses in the specialized condition are true
            flag = True
            for clause in spec.condition.clauses:
                clause = analyzer.canonical_simplify(clause)
                if isinstance(clause, tvm.tir.IntImm) and clause.value:
                    continue
                flag = False
                break
            if flag:
                for impl in spec.implementations:
                    ret.append(impl)
        else:
            for impl in spec.implementations:
                ret.append(impl)
return ret

fstrategy指向的是op attr的"FTVMStrategy"对应的函数。比如con2d注册的策略有:

def conv2d_strategy(attrs, inputs, out_type, target):
    """conv2d generic strategy"""
    logger.warning("conv2d is not optimized for this platform.")
    strategy = _op.OpStrategy()
    data, kernel = inputs
    dilation = get_const_tuple(attrs.dilation)
    groups = attrs.groups
    layout = attrs.data_layout
    kernel_layout = attrs.kernel_layout
    (dilation_h, dilation_w) = dilation
    if dilation_h > 1 or dilation_w > 1:
        raise ValueError("dilation should be positive value")


    if groups == 1:
        if layout == "NCHW":
            assert kernel_layout == "OIHW"
            strategy.add_implementation(
                wrap_compute_conv2d(topi.nn.conv2d_nchw),
                wrap_topi_schedule(topi.generic.schedule_conv2d_nchw),
                nam)

可见一个conv2d即使同一个target也会注册不同的策略。Add_implementation将会把compute,schedule的具体函数注册到strategy中。Strategy是一个包含了一个relay算子implementation方式的数据结构。它包含了很多OpSpecialization,每个OpSpecialization中包含一些列OpImplementation,OpImplementation中就对应着schedule和compute的具体方式,schedule是一个算子计算的排布,compute是对应了TOPI库算子。

获得了所有有效implementation之后,会依据两种方式选择,一种是通过auto TVM来自动化搜索最优的实现方式,另外一种在不适用auto TVM工具情况下,会选择plevel最大的implementation。选择好了implementation之后,就调用src/relay/backend/compile_engine.cc中的LoweredOutput类建立一个实例。可以看出,lower_call实现了将relay算子统一用更底层的的抽象进行了表示。这种表示中包含了relay算子,以及这个算子的计算方式以及schedule信息。这样就方便后边对其进行schedule优化了。

然后将这些LoweredOutput进行打包成CachedFuncNode。CachedFuncNode会作为后边schedule优化的入参。

审核编辑:汤梓红

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • TVM
    TVM
    +关注

    关注

    0

    文章

    19

    浏览量

    3652
  • relay
    +关注

    关注

    0

    文章

    1

    浏览量

    4445
收藏 人收藏

    评论

    相关推荐

    TVM主要的编译过程解析

    低阶TIR表示。  Codegen:内存分配和硬件可执行程序生成。  图导入  通过一个tensorflow的reception网络来熟悉编译过程,其它深度学习框架也具有类似过程。TVM官网可以
    发表于 01-07 16:59

    TVM整体结构,TVM代码的基本构成

    ,编译器包括了前端和后端,前端主要实现从tensorflow等深度学习框架描述的网络结构形式新表示的转化,后端完成编译器中间表示硬件可执行程序的转化。前端对硬件应该是透明的,它的主要挑战在于如何设计出
    发表于 01-07 17:21

    TVM中将计算算符有哪几种

    TVM中将计算算符分成四种
    发表于 01-26 06:34

    TVM的编译流程

    TVM主要的编译过程
    发表于 02-23 07:43

    什么是frame relay,frame relay概念

    什么是frame relay,frame relay概念 物理拓扑:交换机,存取链路,Trunks,CSU/DSU
    发表于 06-11 09:21 3143次阅读
    什么是frame <b class='flag-5'>relay</b>,frame <b class='flag-5'>relay</b>概念

    《HTML 5 入门精通》-中文学习教程

    《HTML 5 入门精通》-中文学习教程.pdf 《HTML 5 入门精通》-中文学习
    发表于 11-02 17:45 0次下载

    TVM用于移动端常见的ARM GPU,提高移动设备对深度学习的支持能力

    的压力。 TVM是一个端端的IR堆栈,它可以解决学习过程中的资源分配问题,从而轻松实现硬件优化。在这篇文章中,我们将展示如何用TVM/NNVM为ARM Mali GPU生成高效
    的头像 发表于 01-18 13:38 1.1w次阅读

    什么是波场虚拟机TVM

    TVM与现有的开发生态系统无缝连接,并支持 DPoS。 TVM最初与 EVM 环境兼容,因此开发人员可以使用Solidity和其他语言在 Remix 环境中开发,调试和编译智能合约,而不是学习
    发表于 05-15 09:46 3224次阅读
    什么是波场虚拟机<b class='flag-5'>TVM</b>

    TVM的编译流程是什么

    TVM主要的编译过程如下图:Import:将tensorflow,onnx,pytorch等构建的深度学习模型导入,转化成TVM的中间层表示IR。Lower:将高层IR表示转化成低阶TIR表示。Codegen:内存分配和硬件可执
    的头像 发表于 02-08 14:51 1662次阅读
    <b class='flag-5'>TVM</b>的编译流程是什么

    TVM学习(三)编译流程

    TVM主要的编译过程如下图:Import:将tensorflow,onnx,pytorch等构建的深度学习模型导入,转化成TVM的中间层表示IR。Lower:将高层IR表示转化成低阶TIR表示。Codegen:内存分配和硬件可执
    发表于 01-26 09:23 13次下载
    <b class='flag-5'>TVM</b><b class='flag-5'>学习</b>(三)编译流程

    TVM学习(四)codegen

    接着上一章继续深入代码,在BuildRelay中会调用Codegen函数。这个函数实现在src/relay/backend/graph_runtime_codegen.cc中。Codegen实现了内存的分配,IR节点到TIR节点的转换,tir图节点的一个调度优化。
    发表于 01-27 06:43 8次下载
    <b class='flag-5'>TVM</b><b class='flag-5'>学习</b>(四)codegen

    TVM学习(二):算符融合

    算符融合将多个计算单元揉进一个计算核中进行,减少了中间数据的搬移,节省了计算时间。TVM中将计算算符分成四种: 1 injective。一一映射函数,比如加法,点乘等。 2 reduction。输入
    发表于 02-19 06:50 10次下载
    <b class='flag-5'>TVM</b><b class='flag-5'>学习</b>(二):算符融合

    使用TVM在android中进行Mobilenet SSD部署

    所谓TVM,按照正式说法:就是一种将深度学习工作负载部署硬件的端端IR(中间表示)堆栈。换一种说法,可以表述为一种把深度学习模型...
    发表于 02-07 12:07 0次下载
    使用<b class='flag-5'>TVM</b>在android中进行Mobilenet SSD部署

    TVM学习(八)pass总结

    Pass是TVM中基于relay IR进行的优化,目的是去除冗余算子,进行硬件友好的算子转换,最终能够提高硬件运行效率。由tensorflow等深度学习框架生成的图机构中,含有很多可以优化的算子
    的头像 发表于 08-02 09:43 1895次阅读
    <b class='flag-5'>TVM</b><b class='flag-5'>学习</b>(八)pass总结

    PyTorch教程7.1全连接层卷积

    电子发烧友网站提供《PyTorch教程7.1全连接层卷积.pdf》资料免费下载
    发表于 06-05 11:50 0次下载
    PyTorch教程7.1<b class='flag-5'>之</b><b class='flag-5'>从</b>全连接层<b class='flag-5'>到</b>卷积