PyTorch支持两种模式:eager模式和script模式。eager模式主要用于模型的编写、调试和训练,script模式主要用于模型的部署。我们使用torch.onnx.export导出onnx的过程,也是对模型做script模式转换的过程,这里面主要设计两个函数:torch.jit.tracetorch.jit.script。这两个函数都是将python代码转换为TorchScript的两种不同的方法。

torch.jit.trace将一个特定的输入(通常是一个张量,需要我们提供一个input)传递给一个PyTorch模型,torch.jit.trace会跟踪此input在model中的计算过程,然后将其转换为Torch脚本。这个方法适用于那些在静态图中可以完全定义的模型,例如具有固定输入大小的神经网络。通常用于转换预训练模型。

torch.jit.script直接将Python函数(或者一个Python模块)通过python语法规则和编译转换为Torch脚本。torch.jit.script更适用于动态图模型,这些模型的结构和输入可以在运行时发生变化。例如,对于RNN或者一些具有可变序列长度的模型,使用torch.jit.script会更为方便。

在通常情况下,更应该倾向于使用torch.jit.trace而不是torch.jit.script

在模型部署方面,onnx被大量使用。而导出onnx的过程,也是model进行torch.jit.trace的过程,因此这里我们把torch的trace做稍微详细一点的介绍。

为了能够把模型编写的更能够被jit trace,需要对代码做一些妥协,例如:

1.如果model中有DataParallel的子模块,或者model中有将tensors转换为numpy arrays,或者调用了opencv的函数等,这种情况下,model不是一个正确的在单个设备上、正确连接的graph,这种情况下,不管是使用torch.jit.script还是torch.jit.trace都不能trace出正确的TorchScript来。

2.model的输入输出应该是Union[Tensor, Tuple[Tensor], Dict[str, Tensor]]的类型,而且在dict中的值,应该是同样的类型。但是对于model中间子模块的输入输出,可以是任意类型,例如dicts of Any, classes, kwargs以及python支持的都可以。对于model输入输出类型的限制是比较容易满足的,在Detectron2中,有类似的例子:

outputs = model(inputs)   # inputs和outputs是python的类型, 例如dictsor classes
# torch.jit.trace(model, inputs)  # 失败!trace只支持Union[Tensor,Tuple[Tensor], Dict[str, Tensor]]类型
adapter = TracingAdapter(model, inputs)  # 使用Adapter,将modelinputs包装为trace支持的类型
traced = torch.jit.trace(adapter, adapter.flattened_inputs)  # 现在以trace成功

# Traced model的输出只能是tuple tensors类型:
flattened_outputs = traced(*adapter.flattened_inputs)
# 再通过adapter转换为想要的输出类型
new_outputs = adapter.outputs_schema(flattened_outputs)

3.一些数值类型的问题。比如下面的代码片段

import torch
a=torch.tensor([1,2])
print(type(a.size(0)))
print(type(a.size()[0]))
print(type(a.shape[0]))

在eager mode下,这几个返回值的类型都是int型。上面代码的输出为

<class 'int'>
<class 'int'>
<class 'int'>

但是在trace mode下,这几个表达式的返回值类型都是Tensor类型。因此,有些表达式使用不当,如果在trace过程中,一些shape表达式的返回值类型是int型,那么可能造成这块代码没有被trace。在代码中,可以通过使用torch.jit.is_tracing来检查这块代码在trace mode下有没有被执行。

4.由于动态的control flow,造成模型没有被完整的trace。看下面的例子:

import torch

def f(x):
    return torch.sqrt(x) if x.sum() > 0 else torch.square(x)

m = torch.jit.trace(f, torch.tensor(3))
print(m.code)

输出为

def f(x: Tensor) -> Tensor:
  return torch.sqrt(x)

可以看到trace后的model只保留了一条分支。因此由于输入造成的dynamic的control flow,trace后容易出现错误。

这种情况下,我们可以使用torch.jit.script来进行TorchScript的转换。

import torch

def f(x):
    return torch.sqrt(x) if x.sum() > 0 else torch.square(x)

m = torch.jit.script(f)
print(m.code)

输出为

def f(x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = torch.sqrt(x)
  else:
    _0 = torch.square(x)
  return _0

在大多数情况下,我们应该使用torch.jit.trace,但是像上面的这种dynamic control flow的情况,我们可以混合使用torch.jit.tracetorch.jit.script,在本文后面会进行阐述。

另外在一些blog中,对于dynamic control flow的定义是有错误的,例如if x[0] == 4: x += 1是dynamic control flow,但是

model: nn.Sequential = ...
for m in model:
  x = m(x)

以及

class A(nn.Module):
  backbone: nn.Module
  head: Optiona[nn.Module]
  def forward(self, x):
    x = self.backbone(x)
    if self.head is not None:
        x = self.head(x)
    return x

都不是dynamic control flow。dynamic control flow是由于对输入条件的判断造成的不同分支的执行。

5.trace过程中,将变量trace成了常量。看下面一个例子

import torch
a, b = torch.rand(1), torch.rand(2)

def f1(x): return torch.arange(x.shape[0])
def f2(x): return torch.arange(len(x))

print(torch.jit.trace(f1, a)(b))
# 输出: tensor([0, 1])
# 可以看到trace后的model是没问题的,这里使用变量a作为torch.jit.trace的example input,然后将转换后的TorchScript用变量b作为输入,正常情况下,b的shape是2维的,因此返回值是tensor([0,1])是正确的

print(torch.jit.trace(f2, a)(b))
# 输出:
# TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
# tensor([0])
# 可以看到这个输出结果是错误的,b的维度是2维,输出应该是tensor([0,1]),这里torch.jit.trace也提示了,使用len可能会造成不正确的trace。

# 我们打印一下两者的区别
print(torch.jit.trace(f1, a).code, '\n',torch.jit.trace(f2, a).code)
# 输出
# def f1(x: Tensor) -> Tensor:
#   _0 = ops.prim.NumToTensor(torch.size(x, 0))
#   _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
#   return _1

#  def f2(x: Tensor) -> Tensor:
#   _0 = torch.arange(1, dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
#   return _0

# TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.

# 从trace的code中可以看出,使用x.shape这种方式,在trace后的code里面,是有shape的一个变量值存在的,但是直接使用len这种方式,trace后的code里面,就直接是1

我们导出onnx的过程,也是进行torch.jit.trace的过程,在导出onnx的时候,有时候也会遇到

TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.

这样的提示信息,这时候要检查一下代码中是不是有可能trace过程中,变量会被当做常量的情况,有可能会导致导出的onnx精度异常。

除了len会导致trace错误,其他几个也会导致trace出现问题:

  • .item()会在trace过程中将tensors转为int/float

  • 任何将torch类型转为numpy/python类型的代码

  • 一些有问题的算子,例如advanced indexing

  1. torch.jit.trace不会对传入的device生效
import torch
def f(x):
    return torch.arange(x.shape[0], device=x.device)
m = torch.jit.trace(f, torch.tensor([3]))
print(m.code)
# 输出
# def f(x: Tensor) -> Tensor:
#   _0 = ops.prim.NumToTensor(torch.size(x, 0))
#   _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
#   return _1
print(m(torch.tensor([3]).cuda()).device)
# 输出:device(type='cpu')

trace不会对传入的cuda device生效。

为了保证trace的正确,我们可以通过一下的一些方法来尽量保证trace后的模型不会出错:

1.注意warnings信息。类似这样的TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results. TraceWarnings信息,它会造成模型的结果有可能不正确,但是它只是个warning等级。

2.做单元测试。需要验证一下eager mode的模型输出与trace后的模型输出是否一致。

assert allclose(torch.jit.trace(model, input1)(input2), model(input2))

3.避免一些特殊的情况。例如下面的代码

if x.numel() > 0:
  output = self.layers(x)
else:
  output = torch.zeros((0, C, H, W))  # 会创建一个空的输出

避免一些特殊情况比如空的输入输出之类的。

4.注意shape的使用。前面提到,tensor.size()在trace过程中会返回Tensor类型的数据,Tensor类型会在计算过程中被添加到计算图中,应该避免将Tensor类型的shape转为了常量。主要注意以下两点:

  • 使用torch.size(0)来代替len(tensor),因为torch.size(0)返回的是Tensorlen(tensor)返回的是int。对于自定义类,实现一个.size方法或者使用.__len__()方法来代替len(),例如这个例子
  • 不要使用int()或者torch.as_tensor来转换size的类型,因为这些操作也会被视为常量。

5.混合tracing和scripting方法。可以使用torch.jit.script来转换一些torch.jit.trace不能搞定的小的代码片段,混合使用tracing和scripting,基本可以解决所有的问题。

混合使用tracing和scripting

tracing和scripting都有他们的问题,混合使用可以解决大部分问题。但是为了尽可能减小对于代码质量的负面影响,大部分情况下,都应该使用torch.jit.trace,必要时才使用torch.jit.script

1.在使用torch.jit.trace时,使用@script_if_tracing装饰器可以让被装饰的函数使用scripting方式进行编译。

def forward(self, ...):
  # ... some forward logic
  @torch.jit.script_if_tracing
  def _inner_impl(x, y, z, flag: bool):
      # use control flow, etc.
      return ...
  output = _inner_impl(x, y, z, flag)
  # ... other forward logic

但是使用@script_if_tracing时,需要保证函数中没有pytorch的modules,如果有的话,需要做一些修改,例如下面的:

# 因为代码中有self.layers(),是一个pytorch的module,因此不能使用@script_if_tracing
if x.numel() > 0:
  x = preprocess(x)
  output = self.layers(x)
else:
  # Create empty outputs
  output = torch.zeros(...)

这里需要做如下修改:

# 需要将self.layers移出if判断,这时候可以用@script_if_tracing
if x.numel() > 0:
  x = preprocess(x)
else:
  # Create empty inputs
  x = torch.zeros(...)
# 需要将self.layers()修改为支持empty的输入,或者将原先的条件判断加入到self.layers中
output = self.layers(x)

2.合并多次tracing的结果

使用torch.jit.script生成的模型相比使用torch.jit.trace有两个好处:

  • 可以使用条件控制流,例如模型中使用一个bool值来控制forward的flow,在traced modules里面是不支持的
  • 使用traced module,只能有一个forward()函数,但是使用scripted module,可以有多个前向计算的函数
class Detector(nn.Module):
  do_keypoint: bool

  def forward(self, img):
      box = self.predict_boxes(img)
      if self.do_keypoint:
          kpts = self.predict_keypoint(img, box)

  @torch.jit.export
  def predict_boxes(self, img): pass

  @torch.jit.export
  def predict_keypoint(self, img, box): pass

对于这种有bool值的控制流,除了使用script,还可以多次进行trace,然后将结果合并。

det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)

然后将他们的weight复制一遍,并合并两次trace的结果

det2.submodule.weight = det1.submodule.weight
class Wrapper(nn.ModuleList):
  def forward(self, img, do_keypoint: bool):
    if do_keypoint:
        return self[0](img)
    else:
        return self[1](img)
exported = torch.jit.script(Wrapper([det1, det2]))

trace和script的性能

tracing总是会比scripting生成一样或者更简单的计算图,因此性能会更好一些。因为scripting会完整的表达python代码的逻辑,甚至一些不必要的代码也会如实表达。例如下面的例子:

class A(nn.Module):
  def forward(self, x1, x2, x3):
    z = [0, 1, 2]
    xs = [x1, x2, x3]
    for k in z: x1 += xs[k]
    return x1
model = A()
print(torch.jit.script(model).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
#   z = [0, 1, 2]
#   xs = [x1, x2, x3]
#   x10 = x1
#   for _0 in range(torch.len(z)):
#     k = z[_0]
#     x10 = torch.add_(x10, xs[k])
#   return x10
print(torch.jit.trace(model, [torch.tensor(1)] * 3).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
#   x10 = torch.add_(x1, x1)
#   x11 = torch.add_(x10, x2)
#   return torch.add_(x11, x3)

总结

tracing具有明显的局限性:这篇文章的大部分篇幅都在谈论tracing的局限性以及如何解决这些问题。实际上,这正是tracing的优势所在:它有明确的局限性(和解决方案),因此你可以推理它是否有效。

相反,scripting更像是一个黑盒子:在尝试之前,没有人知道它是否有效。文章中没有提到如何修复scripting的任何诀窍:有很多诀窍,但不值得你花时间去探究和修复一个黑盒子。

tracing和scripting都会影响代码的编写方式,但tracing因为我们明确它的要求,对我们原始的代码造成的一些修改也不会太严重:

  • 它限制了输入/输出格式,但仅限于最外层的模块。(如上所述,这个问题可以通过一个wrapper解决)。
  • 它需要修改一些代码才能通用(例如在tracing时添加一些scripting),但这些修改只涉及受影响模块的内部实现,而不是它们的接口。