0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

XLA和PyTorch的链接代码示例

jf_pmFSk4VX 来源:GiantPandaCV 2023-11-17 10:54 次阅读

XLA和PyTorch的链接

前言

XLA (Accelerated Linear Algebra)是一个开源的机器学习编译器,对PyTorch、Tensorflow、JAX等多个深度学习框架都有支持。最初XLA实际上是跟Tensorflow深度结合的,很好地服务了Tensorflow和TPU,而与XLA的结合主要依赖于社区的支持,即torch-xla。

torch-xla在支持XLA编译的基础上,较大限度地保持了PyTorch的易用性,贴一个官方的DDP训练的例子:

importtorch.distributedasdist
-importtorch.multiprocessingasmp
+importtorch_xla.core.xla_modelasxm
+importtorch_xla.distributed.parallel_loaderaspl
+importtorch_xla.distributed.xla_multiprocessingasxmp
+importtorch_xla.distributed.xla_backend

def_mp_fn(rank,world_size):
...

-os.environ['MASTER_ADDR']='localhost'
-os.environ['MASTER_PORT']='12355'
-dist.init_process_group("gloo",rank=rank,world_size=world_size)
+#RankandworldsizeareinferredfromtheXLAdeviceruntime
+dist.init_process_group("xla",init_method='xla://')
+
+model.to(xm.xla_device())
+#`gradient_as_bucket_view=True`requiredforXLA
+ddp_model=DDP(model,gradient_as_bucket_view=True)

-model=model.to(rank)
-ddp_model=DDP(model,device_ids=[rank])
+xla_train_loader=pl.MpDeviceLoader(train_loader,xm.xla_device())

-forinputs,labelsintrain_loader:
+forinputs,labelsinxla_train_loader:
optimizer.zero_grad()
outputs=ddp_model(inputs)
loss=loss_fn(outputs,labels)
loss.backward()
optimizer.step()

if__name__=='__main__':
-mp.spawn(_mp_fn,args=(),nprocs=world_size)
+xmp.spawn(_mp_fn,args=())

将一段PyTorch代码改写为torch-xla代码,主要就是三个方面:

将模型和数据放到xla device上

适当的时候调用xm.mark_step

某些组件该用pytorchx-xla提供的,比如amp和spawn

其中第二条并没有在上面的代码中体现,原因是为了让用户少改代码,torch-xla将mark_step封装到了dataloader中,实际上不考虑DDP的完整训练的过程可以简写如下:

device=xm.xla_device()
model=model.to(device)
fordata,labelinenumerate(dataloader):
data,label=data.to(device),label.to(device)
output=model(data)
loss=func(output,label)
loss.backward()
optimizer.step()
xm.mark_step()

xm.mark_step的作用就是"告诉"框架:现在对图的定义告一段落了,可以编译并执行计算了。既然如此,那么mark_step之前的内容是做了什么呢?因为要在mark_step之后才编译并计算,那么前面肯定不能执行实际的运算。这就引出了Trace和LazyTensor的概念。

其实到了这里,如果对tensorflow或者torch.fx等比较熟悉,就已经很容易理解了,在mark_step之前,torch-xla将torch Tensor换成了LazyTensor,进而将原本是PyTorch中eager computation的过程替换成了trace的过程,最后生成一张计算图来优化和执行。简而言之这个过程是PyTorch Tensor -> XLATensor -> HLO IR,其中HLO就是XLA所使用的IR。在每次调用到torch op的时候,会调用一次GetIrValue,这时候就意味着一个节点被写入了图中。更具体的信息可以参考XLA Tensor Deep Dive这部分文档。需要注意的是,trace这个过程是独立于mark_step的,即便你的每个循环都不写mark_step,这个循环也可以一直持续下去,只不过在这种情况下,永远都不会发生图的编译和执行,除非在某一步trace的时候,发现图的大小已经超出了pytorch-xla允许的上限。

PyTorch与torch-xla的桥接

知晓了Trace过程之后,就会好奇一个问题:当用户执行一个PyTorch函数调用的时候,torch-xla怎么将这个函数记录下来的?

最容易想到的答案是“torch-xla作为PyTorch的一个编译选项,打开的时候就会使得二者建立起映射关系”,但很可惜,这个答案是错误的,仔细看PyTorch的CMake文件以及torch-xla的编译方式就会明白,torch-xla是几乎单向依赖于PyTorch的(为什么不是全部后面会讲)。既然PyTorch本身在编译期间并不知道torch-xla的存在,那么当用户使用一个xla device上的Tensor作为一个torch function的输入的时候,又经历了怎样一个过程调用到pytorch-xla中的东西呢?

从XLATensor开始的溯源

尽管我们现在并不知道怎么调用到torch-xla中的,但我们知道PyTorch Tensor一定要转换成XLATensor(参考tensor.h),那么我们只需要在关键的转换之处打印出调用堆栈,自然就可以找到调用方,这样虽然不能保证找到PyTorch中的位置,但是能够找到torch-xla中最上层的调用。注意到XLATensor只有下面这一个创建函数接受at::Tensor作为输入,因此就在这里面打印调用栈。

XLATensorXLATensor::Create(constat::Tensor&tensor,constDevice&device)

测试的用例很简单,我们让两个xla device上的Tensor相乘:

importtorch_xla.core.xla_modelasxm
importtorch

device=xm.xla_device()
a=torch.normal(0,1,(2,3)).to(device)
b=torch.normal(0,1,(2,3)).to(device)

c=a*b

在上述位置插入堆栈打印代码并重新编译、安装后运行用例,可以看到以下输出(截取部分):

usr/local/lib/python3.8/dist-packages/_XLAC.cpython-38-x86_64-linux-gnu.so(_ZN9torch_xla15TensorToXlaDataERKN2at6TensorERKNS_6DeviceEb+0x64d)[0x7f086098b9ed]
/usr/local/lib/python3.8/dist-packages/_XLAC.cpython-38-x86_64-linux-gnu.so(_ZNK9torch_xla9XLATensor19GetIrValueForTensorERKN2at6TensorERKNS_6DeviceE+0xa5)[0x7f0860853955]
/usr/local/lib/python3.8/dist-packages/_XLAC.cpython-38-x86_64-linux-gnu.so(_ZNK9torch_xla9XLATensor10GetIrValueEv+0x19b)[0x7f0860853d5b]
/usr/local/lib/python3.8/dist-packages/_XLAC.cpython-38-x86_64-linux-gnu.so(_ZN9torch_xla9XLATensor3mulERKS0_S2_N3c108optionalINS3_10ScalarTypeEEE+0x3f)[0x7f086087631f]
/usr/local/lib/python3.8/dist-packages/_XLAC.cpython-38-x86_64-linux-gnu.so(_ZN9torch_xla18XLANativeFunctions3mulERKN2at6TensorES4_+0xc4)[0x7f08606d4da4]
/usr/local/lib/python3.8/dist-packages/_XLAC.cpython-38-x86_64-linux-gnu.so(+0x19d158)[0x7f08605f7158]
/usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so(_ZN2at4_ops10mul_Tensor10redispatchEN3c1014DispatchKeySetERKNS_6TensorES6_+0xc5)[0x7f0945c9d055]
/usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so(+0x2b8986c)[0x7f094705986c]
/usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so(+0x2b8a37b)[0x7f094705a37b]
/usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so(_ZN2at4_ops10mul_Tensor4callERKNS_6TensorES4_+0x157)[0x7f0945cee717]
/usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so(+0x3ee91f)[0x7f094e4b391f]
/usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_python.so(+0x3eeafb)[0x7f094e4b3afb]
python()[0x5042f9]

明显可以看到是从python的堆栈调用过来的,分析一下可以得知_ZN2at4_ops10mul_Tensor10redispatchEN3c1014DispatchKeySetERKNS_6TensorES6_+0xc5对应的定义是at::DispatchKeySet, at::Tensor const&, at::Tensor const&)+0xc5

虽然这里意义仍有些不明,但我们已经可以做出推测了:redistpatch函数是根据DispatchKeySet来决定将操作dispatch到某个backend上,xla的device信息就被包含在其中。而后面两个输入的const at::Tensor&就是乘法操作的两个输入。

根据上面的关键字redispatch来寻找,我们可以找到这样一个文件gen.py,其中的codegen函数很多,但最显眼的是下面的OperatorGen:

@dataclass(frozen=True)
classComputeOperators:
target:Union[
Literal[Target.DECLARATION],
Literal[Target.DEFINITION]
]

@method_with_native_function
def__call__(self,f:NativeFunction)->str:
sig=DispatcherSignature.from_schema(f.func)
name=f.func.name.unambiguous_name()
call_method_name='call'
redispatch_method_name='redispatch'

ifself.targetisTarget.DECLARATION:
returnf"""
structTORCH_API{name}{{
usingschema={sig.type()};
usingptr_schema=schema*;
//SeeNote[staticconstexprchar*membersforwindowsNVCC]
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name,"aten::{f.func.name.name}")
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name,"{f.func.name.overload_name}")
STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str,{cpp_string(str(f.func))})
static{sig.defn(name=call_method_name,is_redispatching_fn=False)};
static{sig.defn(name=redispatch_method_name,is_redispatching_fn=True)};
}};"""
elifself.targetisTarget.DEFINITION:
defns=f"""
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name},name,"aten::{f.func.name.name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name},overload_name,"{f.func.name.overload_name}")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name},schema_str,{cpp_string(str(f.func))})

//aten::{f.func}
staticC10_NOINLINEc10::TypedOperatorHandle<{name}::schema>create_{name}_typed_handle(){{
returnc10::singleton()
.findSchemaOrThrow({name}::name,{name}::overload_name)
.typed<{name}::schema>();
}}
"""

foris_redispatching_fnin[False,True]:
ifis_redispatching_fn:
dispatcher_exprs_str=','.join(['dispatchKeySet']+[a.nameforainsig.arguments()])
dispatcher_call='redispatch'
method_name=f'{name}::{redispatch_method_name}'
else:
dispatcher_exprs_str=','.join([a.nameforainsig.arguments()])
dispatcher_call='call'
method_name=f'{name}::{call_method_name}'

defns+=f"""
//aten::{f.func}
{sig.defn(name=method_name,is_redispatching_fn=is_redispatching_fn)}{{
staticautoop=create_{name}_typed_handle();
returnop.{dispatcher_call}({dispatcher_exprs_str});
}}
"""
returndefns
else:
assert_never(self.target)

对于每个算子,PyTorch会(在编译前)在这里生成许多类,这些类会有静态成员call或者redispatch,其中redispatch负责分发具体的实现。这里的codegen比较繁琐,这里就不再细讲。

注册PyTorch库实现

即便我们找到了上面redispatch和codegen的线索,看起来仍然不足以解释PyTorch到torch-xla的桥接,因为PyTorch和torch-xla两个库之间的调用,必须要有符号的映射才可以,而不是一些函数形式上的相同。PyTorch中是有Dispatcher机制的,这个机制很常见于很多框架,比如oneflow也是有一套类似的Dispatcher机制。这套机制最大的好处就是在尽可能减少侵入式修改的前提下保证了较高的可扩展性。简而言之,我们的op有一种定义,但可以有多种实现方式,并且这个实现的代码可以不在框架内部,这样就使得框架在保持通用性的同时,易于在特定环境下做针对性的扩展。这套机制本质上就是建立了一个字典,将op映射到函数指针,那么每次调用一个op的时候,我们可以根据一些标识(比如tensor.device)来判断应该调用哪一种实现。

PyTorch中提供了一个宏用来将实现注册,从而让dispatcher可以调用:

#define_TORCH_LIBRARY_IMPL(ns,k,m,uid)
staticvoidC10_CONCATENATE(
TORCH_LIBRARY_IMPL_init_##ns##_##k##_,uid)(torch::Library&);
staticconsttorch::TorchLibraryInitC10_CONCATENATE(
TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_,uid)(
torch::IMPL,
c10::if_constexpr(
[](){
return&C10_CONCATENATE(
TORCH_LIBRARY_IMPL_init_##ns##_##k##_,uid);
},
[](){return[](torch::Library&)->void{};}),
#ns,
c10::k),
__FILE__,
__LINE__);
voidC10_CONCATENATE(
TORCH_LIBRARY_IMPL_init_##ns##_##k##_,uid)(torch::Library&m)

这个宏如果完全展开会是下面这样:

staticvoidTORCH_LIBRARY_IMPL_init_aten_CPU_0(torch::Library&);
staticconsttorch::detail::TorchLibraryInitTORCH_LIBRARY_IMPL_static_init_aten_CPU_0(
torch::IMPL,
(c10::CPU)
?&TORCH_LIBRARY_IMPL_init_aten_CPU_0
:[](torch::Library&)->void{}),
"aten",
c10::CPU),
__FILE__,
__LINE__);
voidTORCH_LIBRARY_IMPL_init_aten_CPU_0(torch::Library&m)

这里比较需要注意的是第二行的TORCH_LIBRARY_IMPL_static_init_aten_CPU_0并不是一个函数,而是一个静态变量,它的作用就是在torch_xla库初始化的时候,将xla定义的op注册到PyTorch中。

关于这部分更详细的介绍可以参考https://zhuanlan.zhihu.com/p/648578629。

从PyTorch调用到torch_xla

xla调用上面所说的宏进行注册的位置在RegisterXLA.cpp这个文件中(codegen的结果),如下:

ORCH_LIBRARY_IMPL(aten,XLA,m){
m.impl("abs",
TORCH_FN(wrapper__abs));

...
}

其中,wrapper__abs的定义如下:

at::Tensorwrapper__abs(constat::Tensor&self){
returntorch_xla::abs(self);
}

显然,这个定义和PyTorch框架内部的算子是完全一致的,只是修改了实现。而XLANativeFunctions::abs的实现可以在aten_xla_type.cpp中找到,如下所示:

at::TensorXLANativeFunctions::abs(constat::Tensor&self){
XLA_FN_COUNTER("xla::");
returnbridge::abs(bridge::GetXlaTensor(self)));
}

到这里已经比较明朗了,注册之后,PyTorch上对于op的调用最终会进入torch_xla的native function中调用对应的op实现,而这些实现的根本都是对XLATensor进行操作,在最终操作执行完成之后,会将作为结果的XLATensor重新转换为torch Tensor,但要注意,这里的结果不一定被实际计算了,也可能只是记录了一下IR,将节点加入图中,这取决于具体的实现。

总结

其实torch-xla官方的文档里是有关于代码生成和算子注册这个过程的描述的,只不过一开始我没找到这个文档,走了一点弯路,但是自己探索也会觉得更明了这个过程。官方文档中的描述如下(节选):

All file mentioned below lives under the xla/torch_xla/csrc folder, with the exception of codegen/xla_native_functions.yaml

xla_native_functions.yaml contains the list of all operators that are lowered. Each operator name must directly match a pytorch operator listed in native_functions.yaml. This file serves as the interface to adding new xla operators, and is an input to PyTorch's codegen machinery. It generates the below 3 files: XLANativeFunctions.h, RegisterXLA.cpp, and RegisterAutogradXLA.cpp

XLANativeFunctions.h and aten_xla_type.cpp are entry points of PyTorch to the pytorch_xla world, and contain the manually written lowerings to XLA for each operator. XLANativeFunctions.h is auto-generated through a combination of xla_native_functions.yaml and the PyTorch core native_functions.yaml file, and contains declarations for kernels that need to be defined in aten_xla_type.cpp. The kernels written here need to construct 'XLATensor' using the input at::Tensor and other parameters. The resulting XLATensor needs to be converted back to the at::Tensor before returning to the PyTorch world.

RegisterXLA.cpp and RegisterAutogradXLA.cpp are auto-generated files that register all lowerings to the PyTorch Dispatcher. They also include auto-generated wrapper implementations of out= and inplace operators.

大概意思就是实际上torch-xla就是根据xla_native_functions.yaml这个文件来生成算子的定义,然后再生成对应的RegisterXLA.cpp中的注册代码,这也跟PyTorch的codegen方式一致。

综合这一整个过程可以看出,PyTorch是保持了高度的可扩展性的,不需要多少侵入式的修改就可以将所有的算子全部替换成自己的,这样的方式也可以让开发者不用去关注dispatcher及其上层的实现,专注于算子本身的逻辑。

编辑:黄飞

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 编译器
    +关注

    关注

    1

    文章

    1616

    浏览量

    49010
  • 机器学习
    +关注

    关注

    66

    文章

    8344

    浏览量

    132289
  • 深度学习
    +关注

    关注

    73

    文章

    5462

    浏览量

    120875
  • pytorch
    +关注

    关注

    2

    文章

    802

    浏览量

    13110

原文标题:XLA和PyTorch的链接

文章出处:【微信号:GiantPandaCV,微信公众号:GiantPandaCV】欢迎添加关注!文章转载请注明出处。

收藏 人收藏

    评论

    相关推荐

    Pytorch自动求导示例

    Pytorch自动微分的几个例子
    发表于 08-09 11:56

    TensorFlow XLA加速线性代数编译器

    编译:在会话级别中打开JIT编译: 这是手动打开 JIT 编译: 还可以通过将操作指定在特定的 XLA 设备(XLA_CPU 或 XLA_GPU)上,通过 XLA 来运行计算: Ao
    发表于 07-28 14:31

    excel vba代码 示例讲解

    excel vba代码 示例讲解
    发表于 09-07 09:36 25次下载
    excel vba<b class='flag-5'>代码</b> <b class='flag-5'>示例</b>讲解

    Pytorch入门教程与范例

    pytorch 是一个基于 python 的深度学习库。pytorch 源码库的抽象层次少,结构清晰,代码量适中。相比于非常工程化的 tensorflow,pytorch 是一个更易入
    发表于 11-15 17:50 5354次阅读
    <b class='flag-5'>Pytorch</b>入门教程与范例

    AD593X代码示例

    AD593X代码示例
    发表于 03-23 08:18 14次下载
    AD593X<b class='flag-5'>代码</b><b class='flag-5'>示例</b>

    BeMicro代码示例

    BeMicro代码示例
    发表于 05-10 12:21 0次下载
    BeMicro<b class='flag-5'>代码</b><b class='flag-5'>示例</b>

    ezLINX™示例PC应用程序源代码

    ezLINX™示例PC应用程序源代码
    发表于 06-05 19:12 1次下载
    ezLINX™<b class='flag-5'>示例</b>PC应用程序源<b class='flag-5'>代码</b>

    机器学习必学的Python代码示例

    机器学习必学的Python代码示例
    发表于 06-21 09:35 14次下载

    华为游戏服务示例代码教程案例

    概述 游戏服务kit安卓示例代码集成了华为游戏服务的众多API,提供了示例代码程序供您参考和使用,下面是对示例
    发表于 04-11 11:09 4次下载

    基于keil的AD7366示例代码

    基于keil的AD7366示例代码分享
    发表于 10-08 14:58 3次下载

    RAA489204 示例代码软件手册

    RAA489204 示例代码软件手册
    发表于 01-10 18:52 0次下载
    RAA489204 <b class='flag-5'>示例</b><b class='flag-5'>代码</b>软件手册

    RAA489204 示例代码软件手册

    RAA489204 示例代码软件手册
    发表于 06-30 19:23 0次下载
    RAA489204 <b class='flag-5'>示例</b><b class='flag-5'>代码</b>软件手册

    安全驱动示例代码和实现

    示例代码获取和集成 本示例中的驱动只实现了对内存的读写操作,并提供了测试使用的TA和CA。 读者可使用如下指令从GitHub上获取到示例代码
    的头像 发表于 10-30 16:07 592次阅读
    安全驱动<b class='flag-5'>示例</b><b class='flag-5'>代码</b>和实现

    自己编写函数示例代码很难吗?分享几个示例

    Q A 问: Arduino Uno的函数示例 我决定自己编写函数示例代码,因为这应该是Arduino中的基本示例。网络上确实有关于使用函数的文档,但是,如果要尝试使用
    的头像 发表于 11-16 16:05 476次阅读
    自己编写函数<b class='flag-5'>示例</b><b class='flag-5'>代码</b>很难吗?分享几个<b class='flag-5'>示例</b>!

    TorchFix:基于PyTorch代码静态分析

    TorchFix是我们最近开发的一个新工具,旨在帮助PyTorch用户维护健康的代码库并遵循PyTorch的最佳实践。首先,我想要展示一些我们努力解决的问题的示例
    的头像 发表于 12-18 15:20 1015次阅读