We provide several tools to support model conversion.
PyTorch neural networks are written in Python, which eases algorithm development. However, Python control flow and third-party libraries make it difficult to export a network to an intermediate representation. We provide a "monkey patch" tool that rewrites an unsupported function into one that can be exported. Here is an example:
    from mmdeploy.core import FUNCTION_REWRITER
    @FUNCTION_REWRITER.register_rewriter(
        func_name='torch.Tensor.repeat', backend='tensorrt')
    def repeat_static(input, *size):
        # The context carries the original (overridden) function.
        ctx = FUNCTION_REWRITER.get_context()
        origin_func = ctx.origin_func
        if input.dim() == 1 and len(size) == 1:
            # Repeat the 1-D tensor as a 2-D tensor, then drop the extra dim.
            return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
        else:
            return origin_func(input, *size)
The function rewriter is easy to use. Just add a decorator with these arguments:
- func_name: the function to override. It can be either a PyTorch function or a custom function. Methods in modules can also be overridden by this tool.
- backend: the inference engine. The function is overridden when the model is exported to this engine. If it is not given, this rewrite becomes the default rewrite, which is used when no rewrite exists for the given backend.
The rewritten function takes the same arguments as the original function. Inside it, a context ctx is available through FUNCTION_REWRITER.get_context(), as in the example above. The context provides useful information such as the deployment config ctx.cfg and the original function (which has been overridden) ctx.origin_func.
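To make the fallback rule and the context concrete, here is a minimal pure-Python sketch of a registry that behaves as described above. It is not mmdeploy's implementation: ToyFunctionRewriter, enter, ToyLib and the backend name 'onnxruntime' are hypothetical, and rewrites are keyed by an owner object and attribute name rather than by a dotted func_name string.

```python
from contextlib import contextmanager
from types import SimpleNamespace


class ToyLib:
    """Hypothetical stand-in for a library with a function to override."""

    @staticmethod
    def scale(x):
        return x * 2


class ToyFunctionRewriter:
    """Hypothetical registry; only illustrates the documented behaviour."""

    def __init__(self):
        # Maps (owner, attribute name) -> {backend name: rewrite function}.
        self._rewrites = {}
        self._ctx_stack = []

    def register_rewriter(self, owner, name, backend='default'):
        def decorator(func):
            self._rewrites.setdefault((owner, name), {})[backend] = func
            return func
        return decorator

    def get_context(self):
        # Context of the rewrite that is currently running.
        return self._ctx_stack[-1]

    def _wrap(self, rewrite, origin, cfg):
        ctx = SimpleNamespace(cfg=cfg, origin_func=origin)

        def patched(*args, **kwargs):
            self._ctx_stack.append(ctx)
            try:
                return rewrite(*args, **kwargs)
            finally:
                self._ctx_stack.pop()
        return patched

    @contextmanager
    def enter(self, backend, cfg=None):
        originals = []
        try:
            for (owner, name), by_backend in self._rewrites.items():
                # Prefer the backend-specific rewrite, else the default one.
                rewrite = by_backend.get(backend, by_backend.get('default'))
                if rewrite is None:
                    continue
                origin = getattr(owner, name)
                originals.append((owner, name, origin))
                setattr(owner, name, self._wrap(rewrite, origin, cfg))
            yield
        finally:
            # Undo the monkey patch so normal code is unaffected.
            for owner, name, origin in reversed(originals):
                setattr(owner, name, origin)


REWRITER = ToyFunctionRewriter()


@REWRITER.register_rewriter(ToyLib, 'scale')
def scale_default(x):
    ctx = REWRITER.get_context()
    return ctx.origin_func(x) + 1


@REWRITER.register_rewriter(ToyLib, 'scale', backend='tensorrt')
def scale_tensorrt(x):
    return x * 10


with REWRITER.enter('tensorrt'):
    trt_result = ToyLib.scale(3)   # backend-specific rewrite is used
with REWRITER.enter('onnxruntime'):
    ort_result = ToyLib.scale(3)   # no 'onnxruntime' rewrite: default is used
restored = ToyLib.scale(3)         # outside the context: original function
```

The patch is applied only inside the context manager, which mirrors the idea that a rewrite takes effect only while the model is exported to a given engine.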
If you want to replace a whole module with another one, we have another rewriter as follows:
    import torch.nn as nn
    from mmdeploy.core import MODULE_REWRITER
    @MODULE_REWRITER.register_rewrite_module(
        'mmagic.models.backbones.sr_backbones.SRCNN', backend='tensorrt')
    class SRCNNWrapper(nn.Module):
        def __init__(self,
                     module,
                     cfg,
                     channels=(3, 64, 32, 3),
                     kernel_sizes=(9, 1, 5),
                     upscale_factor=4):
            super(SRCNNWrapper, self).__init__()
            self._module = module
            # Swap the upsampler of the wrapped module for a bilinear nn.Upsample.
            module.img_upsampler = nn.Upsample(
                scale_factor=module.upscale_factor,
                mode='bilinear',
                align_corners=False)
        def forward(self, *args, **kwargs):
            """Run forward."""
            return self._module(*args, **kwargs)
        def init_weights(self, *args, **kwargs):
            """Initialize weights."""
            return self._module.init_weights(*args, **kwargs)
Just like the function rewriter, add a decorator with these arguments:
- module_type: the module class to rewrite.
- backend: the inference engine. The module is replaced when the model is exported to this engine. If it is not given, this rewrite becomes the default rewrite, which is used when no rewrite exists for the given backend.
All instances of the module in the network will be replaced with instances of this new class. The original module and the deployment config will be passed as the first two arguments.
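The replacement rule can be sketched without PyTorch. The following is a toy illustration only, assuming a module tree stored in a plain dict. ToyModule, Upsampler, toy_register_rewrite_module, toy_patch_model and the backend name 'ncnn' are hypothetical and are not mmdeploy APIs.

```python
class ToyModule:
    """Hypothetical stand-in for a network module with named children."""

    def __init__(self, **children):
        self.children = dict(children)


class Upsampler(ToyModule):
    def __init__(self, mode='bicubic'):
        super().__init__()
        self.mode = mode


# Maps module class -> {backend name: wrapper class}.
_TOY_REWRITES = {}


def toy_register_rewrite_module(module_type, backend='default'):
    def decorator(cls):
        _TOY_REWRITES.setdefault(module_type, {})[backend] = cls
        return cls
    return decorator


def toy_patch_model(module, backend, cfg):
    """Replace every registered module in the tree with its wrapper."""
    # Patch the children first, then decide whether to wrap this module.
    for name, child in list(module.children.items()):
        module.children[name] = toy_patch_model(child, backend, cfg)
    by_backend = _TOY_REWRITES.get(type(module), {})
    wrapper_cls = by_backend.get(backend, by_backend.get('default'))
    if wrapper_cls is None:
        return module
    # The original module and the config are the first two arguments.
    return wrapper_cls(module, cfg)


@toy_register_rewrite_module(Upsampler, backend='tensorrt')
class BilinearUpsamplerWrapper(ToyModule):
    def __init__(self, module, cfg):
        super().__init__()
        self._module = module
        self.cfg = cfg
        module.mode = 'bilinear'


net = ToyModule(body=ToyModule(up=Upsampler()), head=Upsampler())
patched = toy_patch_model(net, backend='tensorrt', cfg={'backend': 'tensorrt'})

# With a backend that has no rewrite and no default, nothing is replaced.
untouched = toy_patch_model(ToyModule(head=Upsampler()), backend='ncnn', cfg={})
```

Both Upsampler instances, including the nested one, end up wrapped for 'tensorrt', while the second tree keeps its original modules.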
The mappings between PyTorch and ONNX are defined in PyTorch with symbolic functions. A custom symbolic function can help us bypass ONNX nodes that are unsupported by the inference engine.
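    import torch.onnx.symbolic_helper as sym_help
    from mmdeploy.core import SYMBOLIC_REWRITER
    @SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True)
    def squeeze_default(g, self, dim=None):
        if dim is None:
            # No dim given: squeeze every axis whose static size is 1.
            dims = []
            for i, size in enumerate(self.type().sizes()):
                if size == 1:
                    dims.append(i)
        else:
            # Read the constant value of dim from the graph.
            dims = [sym_help._get_const(dim, 'i', 'dim')]
        return g.op('Squeeze', self, axes_i=dims)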
The decorator arguments:
- func_name: the name of the function to add the symbolic function to. Use the full path if it is a custom torch.autograd.Function, or just the name if it is a PyTorch built-in function.
- backend: the inference engine. The function is overridden when the model is exported to this engine. If it is not given, this rewrite becomes the default rewrite, which is used when no rewrite exists for the given backend.
- is_pytorch: True if the function is a PyTorch built-in function.
- arg_descriptors: the descriptors of the symbolic function arguments. They are fed to torch.onnx.symbolic_helper._parse_arg.
Just like the function rewriter, the symbolic rewriter provides a context ctx. The context provides useful information such as the deployment config ctx.cfg and the original function (which has been overridden) ctx.origin_func. Note that ctx.origin_func can be used only when is_pytorch==False.
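The squeeze example above always emits a Squeeze node with an explicit list of axes. As a rough illustration of which axes it selects, here is a pure-Python sketch that works on a plain shape tuple instead of a graph value. The helper name squeeze_axes is hypothetical, and the sketch assumes dim is already a plain integer.

```python
def squeeze_axes(sizes, dim=None):
    """Return the explicit axes a Squeeze node would receive."""
    if dim is None:
        # No dim given: pick every axis whose static size is 1.
        return [i for i, size in enumerate(sizes) if size == 1]
    # A single dim was requested: squeeze only that axis.
    return [dim]


all_unit_axes = squeeze_axes((1, 3, 1, 5))      # [0, 2]
single_axis = squeeze_axes((1, 3, 1, 5), dim=2)  # [2]
```

Making the axes explicit at export time means the exported node does not have to infer them from the input shape at run time.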