Introduction
首先,整个 Pytorch
的编译栈如图所示:
我们从前后端分别来进行解析
Frontend
假定我们已经知道计算图的定义,如果不知道的话请看计算图介绍
torch._dynamo
TorchDynamo
的作用是从 PyTorch 中抓取计算图。在使用 dynamo
之前,PyTorch 有过许多尝试,例如 TorchScript
和 TorchFX
等,但结果都不尽人意。
在正式介绍 dynamo
前,我们需要介绍一下其表示 FX Graph。
FX Graph
FX Graph 其实就是 TorchFX,是一种图的表示方法,我们可以通过如下示例来获取一个计算图:
FX IR 或者叫 FX Graph 的特点是:
- 定义简单,只有 6 个 Opcode
- 组织简单,很容易写出后续的解析代码
- 对正向图操作快速,并且 Trace 是 python2python 的,调试方便
- 不能表达控制流(例如
if-else
结构)
而在 dynamo
中,也延续了使用 FX IR 来表示计算图的思路,但消除了不能表达控制流的缺陷
接下来的内容有些涉及底层,可以跳过不看,当然也不会特别底层(例如 CPython 的原理)
dynamo
优化原理 (high-level)
TorchDynamo 是一个 Python 级别 的即时(JIT
Just In Time)编译器,用于在不修改 PyTorch 程序的情况下对其进行加速。TorchDynamo 在 Python 调用帧执行函数Frame Evaluation 时,插入了钩子 Hook。钩子会在执行具体的 Python 字节码(ByteCode
)之前对其进行解析,从中捕捉到 PyTorch 运算符并将其转化成 FX Graph,最后用自定义的后端对图进行编译优化,并导出、返回优化后的字节码。
下面进行详细的解释,但在此之前,需要知道 Python 是如何运行的。
Python 运行原理
如下图所示
在之前都知道 Python 是通过解释器执行的,具体的机理如下:
- Python 源代码被编译成一系列中间字节码
- 由 CPython 虚拟机内部
while
循环不断匹配字节码,并执行对应字节码指令case
分支内部的多条 C 函数
例如:
对应的字节码如下:
帧评估 (帧执行)
这里的帧,实际上和平时所说的函数栈有一些类似,下图可表示 Python 中的函数和 frame 之间的关系:
函数的调用栈,实际上就是递归地创建 Frame,执行 Frame 的过程,例如:
运行结果如图:
显然,函数是运行在 Frame 中的,因此我们可以轻松地从 Frame 中获取任何函数需要的信息,包括调用栈中的函数的名字,甚至在这些函数中创建的局部变量。
换而言之,Frame 实际上包含了代码的信息,如果在执行函数前就能知道这个函数的 Frame 的话,那么就可以解析此 Frame 进而完成整个函数的 Trace
如下图所示,这是 dynamo
的整个过程
解析 Frame 是 dynamo
所做的第一步,通过这一步,我们得到了 PyFrameObject
与 PyCodeObject
然而在模型实际运行时,其调用栈非常复杂,如何自动化解析成为了一个问题。PyTorch 提出了 PEP 来自动化地为每个函数额外加上解析 Frame 的行为
事实上,我们很难在 Python 层面想到一种方法,将某一个修改递归地作用在所有的函数栈上。但我们回顾 Python 的运行原理,可以发现所有的 Frame 的评估,都是依赖于 CPython 解释器的。
因此 ,CPython 解释器的帧执行方式(Frame Evaluation)应该是可扩展的,这样用户就可以用自定义的方式进行 Frame Evaluation。
也就是说,如果 dynamo
能够拓展 CPython 的评估方式,在执行默认的执行函数前实现一些额外的 Frame 解析优化的工作,那么问题就迎刃而解了。
我们可以用下图来解释 dynamo
做的事情(鉴于上面的有些抽象)
dynamo 就是这样做的,他在 set_eval_frame 中将默认的 _PyEval_EvalFrameDefault 替换成自定义的执行函数
dynamo 非常聪明的选择在 Python 层做字节码解析,以回调函数的形式传给自定义的帧评估函数。当我们调用 optimizer('inductor')(fn)
时,dynamo 会将 fn
的帧评估函数替换成 dynamo 自定义的,并且传入回调函数。
当然这部分并不需要了解的太过详细,只需要知道 dynamo 是在翻译到字节码的阶段就捕获计算图了
动态编译
在之前我们已经对 FX Graph 有过介绍,这部分实际上是 dynamo
解析字节码后得到的 IR,但仅凭借字节码,实际上是损失了很多信息的(例如运行时信息),如果光从字节码去做 Trace,这和从 AST 出发做 Trace 区别不大,得到的图也只是一个静态图,不符合 dynamo
想要做到的动态编译,我们没办法完整的生成所需要的 IR,因此这里引入了 VariableTracker,用于承载程序运行时我们可以获得输入信息,我们可以将这个当作构成计算图的节点。
对于动态编译而言,我们期望解决的最大问题就是:不要总是编译。也就是说我们希望他和静态编译相差不多,能够做到一次编译,数次运行,但又具备动态编译的特点,即该编译时才编译。这就需要 Guard
来做到.
在构建 VariableTracker 时,可能会绑定一个或多个 Guard,用于生成监视变量的检查代码, 也就是我们最初提到的 check_fn。 需要注意的是,Graph trace 阶段可能会生成非常多的 Guard,但是最后只有部分 Guard 会被用于生成 check_fn
,这其实也很好理解,因为只有部分变量都会造成模型的动态结构。
对于一个 class Model(nn.Module)
,其 Guard 的输出结果如下:
对于一个 Tensor x
,其 Guard 的输出如下:
Guard 的检查有以下几种:
-
检查变量 id 是否相等(
ID_MATCH
),check_fn
会调用以下函数这里的
check_obj_id
对于上面输出结果中的code
部分, 检查self
参数时,check_obj_id
会根据其 id 是否匹配,来决定是否需要进行重新编译 -
检查 Tensor 是否匹配(
TENSOR_MATCH
),check_fn
会调用以下函数- 数据类型是否发生变化,例如原来数据类型为
float32
,第二次输入时类型变成float16
,返回False
- 数据所在设备是否发生变化,例如原来是在
GPU 0
上的,第二次输入变成在GPU 1
上了,返回False
- 数据的梯度属性是否发生变化,例如原来是需要计算梯度的,第二次却不再要求计算梯度,返回
False
- 数据的形状以及内存排布是否发生变化
此外,Tensor 以外的变量通常采取**一个变量,一个 Guard **的检查策略,而 Tensor 类型的数据则会进行集中检查,即所有 Tensor 变量只会生成一个检查函数:
___check_tensors
,该函数会遍历并检查所有 Tensor。如果任何检查失败了,都会导致重新编译
- 数据类型是否发生变化,例如原来数据类型为
经 dynamo 编译好的函数被保存在 Frame 的 cache 中,从而避免再次编译相同的函数和输入。默认情况下 cache 大小为 64,也就是说,对于同一个 Python 函数,它的输入最多可以有 64 种变化,超过这个限制后 dynamo 不再编译该函数
Graph Break
既然是动态编译,并且我们的做法是只检查模型的输入,而不是实际运行一遍代码后,再判断是否应该重新编译一遍函数,这是显然的,否则我们不如用 TensorFlow
,但这里带来的问题是,如果代码中存在控制流怎么办,例如:
这里的 b.sum()
会返回一个 Tensor,此时无论如何都没有办法仅凭输入去判断会走哪个分支。对于这种情况,dynamo 的做法是:Graph Break
dynamo 会把 toy_example()
拆分为 3 张子图,不能处理的 if
语句由 Python 解释器执行。编译后对应的 Python 函数如下,执行完编译好的子图 __compiled_fn_0()
后,程序返回到 Python 解释器,根据 if
语句的结果选择执行还未编译的子图 __resume_at_30_1()
或 __resume_at_38_2()
:
其中包含了 3 个函数:
-
__compiled_fn_0()
: dynamo 编译好的子图,对应if
语句前面的部分: -
__resume_at_30_1()
: dynamo 未编译的子图,对应if
分支 (dynamo 直接操纵字节码):该函数会在首次执行时被 dynamo 捕获并编译
-
__resume_at_38_2()
: dynamo 未编译的子图,对应else
分支,该函数也会在首次执行时被 dynamo 捕获并编译:其字节码对应如下:
循环展开
dynamo 把 Python 中的循环捕获为循环展开的计算图,即捕获的计算图中不再包含循环。例如下面的代码片段,其中的 for
循环迭代了 4 次、每次执行一次乘法操作:
捕获到的计算图对应的 Python 函数为:
这个过程的原理是 dynamo 在它的 Python 虚拟机模拟器中模拟运行了 FOR_ITER
这条字节码指令,然后捕获在每次迭代中出现的运算,而不是把 for
循环本身捕获到计算图中
注意,这并不是说后端生成的
kernel
就不存在for
循环语句了
内联函数
针对用户函数调用,dynamo 会尝试内联 (inline) 被调函数,从而生成更大的计算图。但如果被掉函数中存在 Graph Break,那么内联就会失败,此时函数调用栈中的每个函数都会产生一个 graph break
下面的代码片段中 test()
调用了递归函数 toy_example()
:
dynamo 在捕获 toy_example(x, 4)
的计算图时,会尝试内联 toy_example(x, 3)
的计算图,依次类推,直到成功内联 toy_example(x, 0)
的计算图。最终生成一个大的计算图,其中的函数调用被展开:
但在下面的代码片段中,用户函数 baz()
无法被内联,因为其中的 if
条件依赖于张量的值,只有在运行时才能确定执行哪个分支,故而存在一个 Graph Break。这个 Graph Break 导致其调用者 bar()
和 foo
都产生了子图,最后总共生成 7 个计算图:
dynamo 通过字节码指令 CALL_FUNCTION
实现内联函数,其中识别用户函数调用并尝试内联,内联失败时恢复主调函数的状态并创建子图,子图编译完后返回解释器执行子函数调用。
Distributed Data Parallel
不需要太过关注这部分
通过数据并行在多 GPU 上训练深度学习模型时,需要调用 allreduce 对所有 GPU 上的梯度进行规约。深度学习框架中往往都把一些参数的梯度放在一个 bucket 中,当这个 bucket 中的所有梯度都已经就绪后,就会使用 allreduce 进行梯度规约。
dynamo 捕获的计算图并不包含 DDP 的 hook 或者 allreduce 节点,如果整个模型被捕获为一张计算图,那么所有的 allreduce 都只能等到反向传播结束才能被触发,导致 allreduce 无法和反向传播 overlap。为了能够在一个 bucket 中的梯度就绪时及时调用 allreduce 进行通信,TorchDynamo 会在每个 bucket 的边界引入 graph break。
小结
总而言之,我们可以通过下图来形式化 dynamo
的执行过程:
得到了 FX IR,但我们还只是得到了正向的计算图,想要反向传播,我们必须得到反向图,这就需要 AOT Autograd
Ahead Of Time Auto Gradient (AOTAutograd
)
- 获取反向传播计算图
- 用不同的后端编译器分别编译正向传播和反向传播计算图
- 针对训练 (training) 做正向传播、反向传播联合优化,比如通过在反向传播中重算 (recompute) 来减少正向传播为反向传播保留的 tensor,从而削减内存需求;
通过这一步,计算图中的算子从 torch
转化到 ATen
算子,它们是 low-level 算子,而不是 Torch 级别的算子,例如 torch.sigmoid
会被下降为 torch.aten.ops.sigmoid.default()
为什么叫 AOTAutograd?因为 PyTorch 反向传播的计算图是在执行正向传播的过程中动态构建的,反向传播的计算图在正向传播结束时才能确定下来。AOTAutograd 以 Ahead-of-Time 的方式同时 trace 正向传播和反向传播,从而在函数真正执行之前拿到正向传播和反向传播的计算图。
总的来说,AOTAutograd 的工作流程 如下:
- 以 AOT 方式通过
__torch_dispatch__
机制 trace 正向传播和反向传播,生成联合计算图 (joint forward and backward graph),它是包含 Aten/Prim 算子的 FX Graph; - 用
partition_fn
把 joint graph 划分为正向传播计算图和反向传播计算图; - 通过
decompositions
把 high-level 算子分解、下沉到粒度更小的算子 (optional) - 调用
fw_compiler
和bw_compiler
分别编译正向传播计算图和反向传播计算图,通过TorchFX
生成编译后的 Python 代码,并整合为一个torch.autograd.Function
__torch_dispatch__
是什么
AOTAutograd 得以工作的核心是 __torch_dispatch__
**。**PyTorch 的核心是一个 dispatcher,它的功能是根据输入 tensor 的属性把算子 dispatch 到具体的 kernel 上,比如根据 tensor 的 device
属性决定是调用 CUDA kernel 还是 CPU 函数执行该算子。
一个算子在 PyTorch 中往往要经过多次 dispatch,__torch_dispatch__
给了用户提供了一个入口,使得用户能够在算子最终 dispatch 前获取对应的算子和输入。
这个过程我们可以简化为查表,例如:
我们现在有一个算子 mul
,其运算的 device
为 cuda,于是我们可以从这个表中找到其底层算子实现的函数指针,从而返回。
这些函数指针我们可以具体化为如下:
这里,每个aten
算子都可能有其对应设备的实现(如果没有的话就会退化,例如 cuda 没有就可能采用默认的算子),我们通过查这个表,从而获得底层通过 C++
实现的函数指针,然后将其返回,赋值给原来的算子,这样就能保证原来的算子可以去调用这个底层算子,从而提高运算效率。
例如 torch
中 Tensor
之间的 dot
运算,在这里就会通过 dispatch
映射到 aten::dot
上去,而这个是在底层 C++
实现的,例如:
当然,对 dispatch
的理解不需要太过深入,我觉得把它当作一个查表的过程就可以了。
而需要注意的是,这一步是多线程完成的,在 C++
中用了一个线程池,从而进行快速的映射,因此调试起来还是很困难的……
通过这一步,用户有机会在 kernel 执行前获取算子和参数,从而可以做很多事情,基于 __torch_dispatch__
的 tracing 正是其中之一。
去重
在之前我们提到了 Torch FX 实现了 make_fx
,与常规的 symbolic tracing 不同,make_fx
是通过 __torch_dispatch__
实现的,AOTAutograd 的 tracing 正是用的 make_fx
。以下面的代码为例:
我们可以得到如下结果:
而 symbolic tracing(也就是 FX 层面的 trace)得到的结果如下:
可以发现,FX 层得到的仍是 torch 级别的算子,而在 make_fx
中已经得到了 aten
的算子,但 make_fx
存在一些问题:当用于 trace 的输入参数中包含重复的 tensor 时,例如:
我们得到的结果如下:
可以发现得到的运算是 而不是 ,除此之外,如果我们使用 torch.autograd.grad(f(x, y), (x, y))
计算函数 f(x, y)
对 (x, y)
的梯度,但如果 x
和 y
是相同的 tensor,trace 出来的梯度就是错的。
存在这样的问题是因为因为使用 __torch_dispatch__
进行 tracing 时使用的是 tensor,而要建立的是 fx.Graph
,怎么把 tensor 映射到 fx.Graph
中的节点?答案是通过 tensor 的 ID(这里可以看作是哈希值),相同的 tensor 会被映射到 fx.Graph
中的同一个 Proxy,因而给被 trace 的函数实际参数去重就很有必要。
因此,AOTAutograd
会在 trace 开始前给函数去重,做法如下:
- 通过
detach
把待 trace 函数的重复参数变为 leaf tensor: 缺点是待 trace 函数不能改变重复参数,例如在重复 tensor 上调用 in-place 算子; - 把重复的参数从函数签名中移除: 捕获的计算图是针对重复参数特化的版本;
Joint Graph
有了基于 __torch_dispatch__
的 tracing 机制,AOTAutograd 就可以 trace 联合正向传播和反向传播计算图。这里的逻辑比较直接,如果用户要想优化的正向传播函数是 fn
,AOTAutograd 则构建并 trace 一个 joint_forward_backward
函数,其中调用正向传播函数 fn
之后,再调用 torch.autograd.grad
执行反向传播。
这里通过上述的 make_fx
来 trace joint_forward_backward
函数,注意,对于每个 算子 而言,都会触发 __torch_dispatch__
,直到遍历完所有的算子,得到一张完整的 joint_graph
Partition
AOTAutograd 用 partition_fn
把 joint graph 划分为正向传播计算图和反向传播计算图,目前内置了两种 partition_fn
:
- default_partition: 模拟了 PyTorch 的默认行为,找出从 forward 的输入到 forward 的输出的所有算子输出,其中被 backward 用到的 tensor 也作为 forward 的输出,是 forward 保留给 backward 的 tensor;
- min_cut_rematerialization_partition: 通过在 backward 中引入重算,减少 forward 给 backward 保留的 tensor,这种重算的思路与 gradient/activation checkpointing 一致;
显然,在划分图的时候需要考虑的问题很多,例如,我们应该按照什么标准来切分正向图和反向图,如何选择 forward 保留给 backward 的算子,其理由是什么。
一般采取的原则是 内存需求最少,这里切分使用的算法是最大流最小割:
- 在源节点 (source) 和 primals 之间各添加一条边,在所有的 tangent’s closure 和目标节点 (sink) 之间各添加一条边,它们组成了一张从 source 到 sink 的有向图,边上的权重是 tensor size;
- 我们需要找到一个合适的切分方法,把这个有向图分成两部分,使得 source 子图到 target 子图之间边上的权重之和最小,这是一个最小割问题;
- 最小割问题的对等问题是最大流问题,已经有标准的解法,直接在该有向图上运行 max-flow 算法即可得到最佳划分方法;
Decompose (Optional)
对于一些算子,我们会将其分解为细粒度的算子,例如 BN
,SiLU
等,都会分解。这里我们以 SiLU
为例:
此函数首先会被映射为 aten.silu
,随后,会进行分解到 aten.ops
级别:
可以看见 aten.silu
被分解为四步:两次转换,一次 sigmoid
和一次矩阵乘
(当然最下面的 copy
是另外一个操作,inplace
,但这里不需要太过在意)
总结
我们已经走完了前端的步骤,形式化如下:
我们举例如下,现在源代码为:
通过前端 lowering
后,我们得到 aten/prims
级别的算子如下:
接着就会进入后端,进行代码生成
Backend
后端的流程如下所示:
首先我们会做一次 lowering
,然后进行调度,最后才会生成 Triton
的 kernel
(这里以 Triton
为例)
Loop-level IR
对于这次的 lowering
,我们使用 loop-level IR
来表示,其对 aten IR
的每一句话做解释,并且每次的解析都会与前文联系起来
这一层 IR
的类型有:
PointWise
Reduction
TensorBox
MatrixMultiplyAdd
等,我们从下图中解释这一层是如何工作的:
首先,我们有前端拿到的 aten IR
:
对于上面的每一句运算,我们都会翻译为 loop-level IR
:
-
convert_element_type
: -
amax
:可以发现,这里将计算的结果存储到
buf0
中 -
sub
:由于
amax
将结果存储到buf0
中,因此这里才能从buf0
中直接load
进来 -
exp
:如果上一条 IR 是
pointwise
的话,那么就会和这一次的进行归约,例如这里,只是在sub
的IR
上加上了tmp4 = exp(tmp3)
并将return
改为了tmp4
因此,这一层的 pass
会对 aten IR
的每一句话进行解析,并且每次的解析都会与前文联系起来,最终得到一个归约的 IR
⇒
loop-level IR
对上面的每一句话进行解析,并且每次的解析都会与前文联系起来,最终得到一个完整的 IR
Schedule
若以先前的:
为例,那么其在 loop-level
层会构建出 个缓冲区。随后,会对这些缓冲区进行 schedule
,内容包括:
注意到,这里我们有些缓冲区启用了 Reduction
,也就是说这里的归约是对于缓冲区而言的,我们将这些缓冲区放在一起,生成一个 kernel
,而其他的缓冲区,则单独生成自己的 kernel
(注意这里的 kernel
是指 triton
的 kernel
,实际上我们可以认为是一个函数)
只有
reduction
的kernel
中会出现循环语句,若只是pointwise
的计算,则不会生成循环
悲哀的是关于如何调度缓冲区,如何分块,采取什么样的策略,为什么,本人都没有很清楚,因为代码写的实在是太太太复杂了……
Triton Kernel
最后就是 triton kernel
的生成了,其采取的策略是:
- 首先生成
load
语句 - 生成
compute
语句 - 生成
store
语句 - 组合三种语句为一个
kernel
- 组合所有
kernel
与一个call
函数和main
模块在一起为一个.py
文件
上述例子生成的文件如下:
此文件生成在 /tmp
文件夹中,后缀为 py
,后续直接运行此文件,可得到 performace
的值,同样,也可在运行中捕获到运算的值
我们设计的 Triton
生成策略如下所示:
事实上也是对 torch 的拙劣模仿而已,生成的代码也只是能跑的程度,而且适用性目前只有几种简单的 case
而已
一些问题
-
loop-level IR
是如何完成到triton kernel
的生成的通过数据结构
GraphLowering
的方法run(*example_input)
也就是一个Fake Tensor
进行生成:
-
调度的策略是什么样的,调度的目的又是什么
调度的目的:由于在前面已经进行了
decompose
(一般在转为aten
算子的时候就已经完成了),因此这里的目的是为了调整buff
的次序,也就是调度内存,以优化内存访问的效率。
Trace
记录
aten IR
到 loop-level IR
在 torch/_inductor/compile_fx.py
中 #179 完成的
其中,输入的 gm
中存储的 code
为:
生成的 loop-level IR
代码在之前已经给出了
随后,得到的 loop-level IR
会通过下一行的 compile_to_fn()
进行到 triton
的转化,生成的 triron
代码会存储在 /tmp/
目录下的一个 .py
文件中,例如这里的 /tmp/torchinductor_xuruiyuan/hi/chirjp335x7uz3omvvfgceqnzplwllaqcywzyumftwtnk4mc5wvo.py
返回的值是一个函数 compiled_fn
,其 __module__
变量存储着上述的文件路径,例如 'torch._inductor.codecache.chirjp335x7uz3omvvfgceqnzplwllaqcywzyumftwtnk4mc5wvo'
如何从 GraphLowering
生成 Triton
内核
注意,这里的
graph
为GraphLowering
类,定义在torch/_inductor/graph.py
中
-
首先,调用
graph.compile_to_fn()
-
这个函数会先去调用
graph
中的compiler_to_module()
,对此返回值,取出其call
属性并返回 -
对于
compiler_to_module()
,首先调用self.codegen()
来生成triton
代码(返回一个py
文件),随后将此代码重命名后返回 -
在
codegen
中,首先调用了self.init_wrapper_code()
,此函数只是检查是否需要使用cpp
包装,一般而言都不需要,于是实例化了一个WrapperCodeGen()
的对象并返回
Quote
WrapperCodeGen
定义在torch/——inductor/codegen/wrapper.py
中,初始化的代码如下:
-
接着,对
graph
中的scheduler
进行实例化,调度的对象为loop-level IR
中构造出的东西,实际上可以视为计算节点:实例化的过程如下:
-
声明一个空的
node
列表,用于新的构造 -
拿到后续计算所依赖的缓冲区名称,例如这里是:
-
遍历传入的参数列表,这里就是在之前传入的列表等,对于列表中的每一个元素,做如下操作:
-
查看此
node
是否存在入度(也就是数据是从什么地方来的,一般为缓冲区名称),例如convolution
的origin
为: -
对
node
的类型进行查看,当然,在这里由于传入的节点均为buffer
,因此不会进入is_no_op
函数事实上我们只对内存进行调度
接着,判断是否为
ComputedBuffer
或TemplateBuffer
,其中TemplateBuffer
给出的解释为Represents a Triton (in the futurue other type) of template operator that we can fuse an epilogue onto.
(显然,对于后续的ComputedBuffer
都会进入这一条分支,并执行self.get_backend(node.get_device()).group_fn
)对于卷积而言,在这里定义为
ExternKernel
,因此会生成特定的内核:而实际上还是生成了基础的调度节点,但增加了新的成员函数而言
-
将刚才生成的
node
添加到最开始创建的node
列表中去卷积
node
重新生成后如下:
-
我们会进入
get_backend(node.get_device()).group_fn
中-
注意这里的
get_device()
返回的均为device(type='cuda', index=0)
-
对于
get_backend()
而言,返回的是一个后端的字典,来方便后面查询是否支持此后端,最开始时显然字典为空,于是会调用create_backend(device)
来创建,并将其存储到self.backends
这个字典中 -
对于
create_backend()
,由于我们这里是cuda:0
并且安装了triton
,因此直接返回TritonScheduling
此类实际上相当于将
Scheduler
传递到了后面的步骤去,如下所示: -
接着,根据
backend
创建SchedulerNode
,注意这里提到了group_fn
,为: -
对于创建调度节点,代码如下:
在生成之前,我们会进行一次优化
simplify_and_reorder()
,在这里以一种与后端无关的方式进行循环转换,做法为:- 删除所有一维的部分
- 将连续的维度融合在一起
- 根据跨度顺序重新排列尺寸
例如这里的
在转化后,新增了如下两个部分:
可以发现,
ranges
从[2, 32, 6, 6]
融合为[2, 32, 36]
-
下一步,通过调用
group_fn
,得到了元组(2*32*36
,1
) -
接着,对读写的依赖进行操作并重写,重写的步骤如下:
- 调用
torch/_inductor/dependencies.py
中的extract_read_writes
进行循环变量的重写,例如这里,[2, 32, 36]
,声明了三个变量(d0, d1, d2)
来对应,并调用LoopBody
(也就是self._body
) 中的__call__
进行处理 - 最后,返回一个
ReadWrites
对象,其读写范围为上面修改后的内容
- 调用
-
最后查看是否需要
reduction
在buf4
时,触发了reduction
,buf4
的loop-level
如下:
-
-
做完调度后,接着就开始直接生成内核,注意,如果是特殊的算子(例如卷积)是不会被翻译为
triton
的,直接生成aten
,否则,我们会进入codegen_kernel
阶段,而是pointwise
还是reduction
,我们需要使用调度后融合的节点来完成这一点,例如,2,3,6,4,5,7 进行了融合,因此我们需要把这几个buff
生成到一个kernel
里面去