[ONNX] Onnx file loads very slowly with onnxruntime, which is exported (dynamic_shapes=True) with torch.onnx_dynamo_export. #6016

Open
Lingstreasure opened this issue Mar 12, 2024 · 2 comments
Labels: converters (Issues related to ONNX converters), question (Questions about ONNX)

Lingstreasure commented Mar 12, 2024

I trained an inpainting model that contains torch.fft.rfftn / torch.fft.irfftn operations and accepts image data of shape [b, 4, h, w]. torch.onnx.export cannot export the operators that produce complex tensors. I managed to export the model dynamically with torch.onnx.dynamo_export, but onnxruntime then takes a long time to load the resulting file. Here is my model: onnx
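
For reference, this is roughly the legacy-exporter call that fails on the complex tensors (a minimal sketch; the exact export arguments here are my own guess, the model construction matches export.py below):

import torch
from model import FFCResNetGenerator

model = FFCResNetGenerator(
    input_nc=4, output_nc=3, ngf=64, n_downsampling=3, n_blocks=9,
    init_conv_kwargs={"ratio_gin": 0, "ratio_gout": 0},
    downsample_conv_kwargs={"ratio_gin": 0, "ratio_gout": 0},
    resnet_conv_kwargs={"ratio_gin": 0.75, "ratio_gout": 0.75},
)
model.eval()
dummy = torch.randn(1, 4, 512, 1024)
# Fails on the complex tensors produced by torch.fft.rfftn / torch.complex in FourierUnit:
torch.onnx.export(
    model, (dummy,), "static_fft.onnx", opset_version=17,
    input_names=["image"], output_names=["out"],
    dynamic_axes={"image": {0: "batch", 2: "height", 3: "width"}},
)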

environment:
os: Ubuntu 20.04.5 LTS
onnx==1.14.1
onnxruntime==1.16.0
onnxscript==0.1.0.dev20240304
torch==2.1.1+cu12.1+cudnn8.9.2

model.py:
# Fast Fourier Convolution NeurIPS 2020
# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf

import torch
import torch.nn as nn


class FourierUnit(nn.Module):
    def __init__(self, in_channels, out_channels, groups=1, fft_norm='ortho', 
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU):
        super(FourierUnit, self).__init__()
        self.groups = groups
        self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2, out_channels=out_channels * 2,
                                          kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
        self.bn = norm_layer(out_channels * 2)
        self.relu = activation_layer(True)
        self.fft_norm = fft_norm

    def forward(self, x):
        batch, channel, h, w = x.shape

        fft_dim = (-2, -1)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)          # (b, c, h, w/2+1, 2)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()              # (b, c, 2, h, w/2+1)
        ffted = ffted.view((batch, 2 * channel, h, -1))                # (b, 2c, h, w/2+1)

        ffted = self.conv_layer(ffted)  # (b, 2c, h, w/2+1)
        ffted = self.relu(self.bn(ffted))

        ffted = ffted.view((batch, channel, 2, h, -1)).permute(0, 1, 3, 4, 2).contiguous()     # (b, c, h, w/2+1, 2)
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])            # (b, c, h, w/2+1)

        ifft_shape_slice = x.shape[-2:]
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)  # (b, c, h, w)
        return output


class SpectralTransform(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, groups=1,
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU, **fu_kwargs):
        super(SpectralTransform, self).__init__()
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()

        self.stride = stride
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels //
                      2, kernel_size=1, groups=groups, bias=False),
            norm_layer(out_channels // 2),
            activation_layer(True)
        )
        self.fu = FourierUnit(
            out_channels // 2, out_channels // 2, groups, norm_layer=norm_layer, 
            activation_layer=activation_layer, **fu_kwargs)
        self.conv2 = torch.nn.Conv2d(
            out_channels // 2, out_channels, kernel_size=1, groups=groups)

    def forward(self, x):
        x = self.downsample(x)
        x = self.conv1(x)
        output = self.fu(x)
        output = self.conv2(x + output)
        return output


class FFC(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin, ratio_gout, stride=1, padding=0,
                 dilation=1, groups=1, bias=False, 
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU,
                 padding_type='reflect', **spectral_kwargs):
        super(FFC, self).__init__()

        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
        self.stride = stride

        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        self.global_in_num = in_cg

        module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
        self.convl2l = module(in_cl, out_cl, kernel_size,
                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
        module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
        self.convl2g = module(in_cl, out_cg, kernel_size,
                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
        self.convg2l = module(in_cg, out_cl, kernel_size,
                              stride, padding, dilation, groups, bias, padding_mode=padding_type)
        module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
        self.convg2g = module(
            in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, norm_layer=norm_layer, 
            activation_layer=activation_layer, **spectral_kwargs)

        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d

    def forward(self, x):
        x_l, x_g = x if isinstance(x, tuple) else (x, torch.tensor(0.0))
        out_xl, out_xg = torch.tensor(0.0), torch.tensor(0.0)
        if self.ratio_gout != 1:
            out_xl = self.convl2l(x_l) + self.convg2l(x_g)
        if self.ratio_gout != 0:
            out_xg = self.convl2g(x_l) + self.convg2g(x_g)
        return out_xl, out_xg


class FFC_BN_ACT(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, ratio_gin, ratio_gout,
                 stride=1, padding=0, dilation=1, groups=1, bias=False,
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
                 padding_type='reflect', **kwargs):
        super(FFC_BN_ACT, self).__init__()
        self.ffc = FFC(in_channels, out_channels, kernel_size,
                       ratio_gin, ratio_gout, stride, padding, dilation,
                       groups, bias, norm_layer, activation_layer, 
                       padding_type=padding_type, **kwargs)
        lnorm = nn.Identity if ratio_gout == 1 else norm_layer
        gnorm = nn.Identity if ratio_gout == 0 else norm_layer
        global_channels = int(out_channels * ratio_gout)
        self.bn_l = lnorm(out_channels - global_channels)
        self.bn_g = gnorm(global_channels)

        lact = nn.Identity if ratio_gout == 1 else activation_layer
        gact = nn.Identity if ratio_gout == 0 else activation_layer
        self.act_l = lact(inplace=True)
        self.act_g = gact(inplace=True)

    def forward(self, x):
        x_l, x_g = self.ffc(x)
        x_l = self.act_l(self.bn_l(x_l))
        x_g = self.act_g(self.bn_g(x_g))
        return x_l, x_g


class FFCResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU, dilation=1,
                 **conv_kwargs):
        super().__init__()
        self.conv1 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
                                norm_layer=norm_layer,
                                activation_layer=activation_layer,
                                padding_type=padding_type,
                                **conv_kwargs)
        self.conv2 = FFC_BN_ACT(dim, dim, kernel_size=3, padding=dilation, dilation=dilation,
                                norm_layer=norm_layer,
                                activation_layer=activation_layer,
                                padding_type=padding_type,
                                **conv_kwargs)

    def forward(self, x):
        x_l, x_g = x if type(x) is tuple else (x, 0)
        id_l, id_g = x_l, x_g

        x_l, x_g = self.conv1((x_l, x_g))
        x_l, x_g = self.conv2((x_l, x_g))

        x_l, x_g = id_l + x_l, id_g + x_g
        out = x_l, x_g
        return out


class ConcatTupleLayer(nn.Module):
    def forward(self, x):
        assert isinstance(x, tuple)
        x_l, x_g = x 
        assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
        if not torch.is_tensor(x_g):
            return x_l
        return torch.cat(x, dim=1)


class FFCResNetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', activation_layer=nn.ReLU, up_norm_layer=nn.BatchNorm2d, 
                 up_activation=nn.ReLU(True), init_conv_kwargs={},  downsample_conv_kwargs={}, 
                resnet_conv_kwargs={}, add_out_act=True,  max_features=1024, out_ffc=False, out_ffc_kwargs={}):
        assert (n_blocks >= 0)
        super().__init__()
        model = [nn.ReflectionPad2d(3),
                 FFC_BN_ACT(input_nc, ngf, kernel_size=7, padding=0, norm_layer=norm_layer,
                            activation_layer=activation_layer, **init_conv_kwargs)]

        ### downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            if i == n_downsampling - 1:
                cur_conv_kwargs = dict(downsample_conv_kwargs)
                cur_conv_kwargs['ratio_gout'] = resnet_conv_kwargs.get('ratio_gin', 0)
            else:
                cur_conv_kwargs = downsample_conv_kwargs
            model += [FFC_BN_ACT(min(max_features, ngf * mult),
                                 min(max_features, ngf * mult * 2),
                                 kernel_size=3, stride=2, padding=1,
                                 norm_layer=norm_layer,
                                 activation_layer=activation_layer,
                                 **cur_conv_kwargs)]

        mult = 2 ** n_downsampling
        feats_num_bottleneck = min(max_features, ngf * mult)

        ### resnet blocks
        for i in range(n_blocks):
            cur_resblock = FFCResnetBlock(feats_num_bottleneck, padding_type=padding_type, activation_layer=activation_layer,
                                          norm_layer=norm_layer, **resnet_conv_kwargs)
            model += [cur_resblock]
        
        model += [ConcatTupleLayer()]

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(min(max_features, ngf * mult),
                                         min(max_features, int(ngf * mult / 2)),
                                         kernel_size=3, stride=2, padding=1, output_padding=1),
                      up_norm_layer(min(max_features, int(ngf * mult / 2))),
                      up_activation]

        if out_ffc:
            model += [FFCResnetBlock(ngf, padding_type=padding_type, activation_layer=activation_layer,
                                     norm_layer=norm_layer, inline=True, **out_ffc_kwargs)]

        model += [nn.ReflectionPad2d(3),
                  nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        if add_out_act:
            model.append(nn.Sigmoid())
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

export.py:
import torch
from model import FFCResNetGenerator

if __name__ == "__main__":
    model = FFCResNetGenerator(
            input_nc=4,
            output_nc=3,
            ngf=64,
            n_downsampling=3,
            n_blocks=9,
            init_conv_kwargs={
                "ratio_gin": 0,
                "ratio_gout": 0,
            },
            downsample_conv_kwargs={
                "ratio_gin": 0,
                "ratio_gout": 0,
            },
            resnet_conv_kwargs={
                "ratio_gin": 0.75,
                "ratio_gout": 0.75,
            }
        )
    model.eval()
    
    input_data = torch.randn(1, 4, 512, 1024)
    args = (input_data,)
    export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
    torch.onnx.dynamo_export(
        model,
        *args,
        export_options=export_options,
    ).save("dynamic_fft.onnx")
    print(f"Dynamic onnx exported to dynamic_fft.onnx")
Running this script raises the following error:
warnings.warn(
Traceback (most recent call last):
  File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 1397, in run_node
    return node.target(*args, **kwargs)
  File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1250, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 1487, in dispatch
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 475, in wordaround_stride_incorrect_op
    raise UnsupportedOperatorException(func)
torch._subclasses.fake_tensor.UnsupportedOperatorException: aten._fft_r2c.default


I modified some code to work around this and successfully exported a dynamic ONNX model. However, it takes about 1 minute for onnxruntime to load that file for inference, which is too slow to be acceptable for my task.

To make the dynamic export work, I made the following modifications:

  1. Comment out two lines in _subclasses/fake_tensor.py of the torch (2.1.1) package. Find the function stride_incorrect_op:

    def stride_incorrect_op(op):
        if op.namespace not in ("aten", "prims"):
            return False
        if op is aten._fft_c2c.default:
            return False
    
        op_name = op.name()
        # if "fft" in op_name:  ### comment the condition expr
        #     return True
        return False

    Then, executing export.py raises another error:

    Traceback (most recent call last):
      File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/onnx/_internal/exporter.py", line 1190, in dynamo_export
        return Exporter(
      File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/onnx/_internal/exporter.py", line 950, in export
        graph_module = pre_export_passes(
      File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/onnx/_internal/exporter.py", line 1250, in pre_export_passes
        analysis.UnsupportedFxNodesAnalysis(
      File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/onnx/_internal/fx/analysis/unsupported_nodes.py", line 74, in analyze
        self._lint(analysis_result, diagnostic_level)
      File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/onnx/_internal/fx/analysis/unsupported_nodes.py", line 38, in _lint
        self.diagnostic_context.log_and_raise_if_error(diagnostic)
      File "/home/d5/anaconda3/envs/test_fft/lib/python3.9/site-packages/torch/onnx/_internal/diagnostics/infra/context.py", line 367, in log_and_raise_if_error
        raise RuntimeErrorWithDiagnostic(diagnostic)
    torch.onnx._internal.diagnostics.infra.context.RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.complex.default']}.

    So I registered a custom implementation of aten::complex in my export.py:

    export.py:
    import onnxscript
    import torch
    from onnxscript import FLOAT, COMPLEX64
    from torch.onnx import register_custom_op_symbolic
    
    from model import FFCResNetGenerator
    
    
    def register_complex_for_torch_dynamo():
        from onnxscript.onnx_opset import opset18 as op
        custom_aten = onnxscript.values.Opset(domain="custom.aten", version=1)
    
        @onnxscript.script(custom_aten)
        def custom_aten_complex(
            real: FLOAT[1, "C", "H", "W"], 
            imag: FLOAT[1, "C", "H", "W"]
        ) -> COMPLEX64[1, "C", "H", "W", 2]:
            real = op.Unsqueeze(real, axes=[-1])
            imag = op.Unsqueeze(imag, axes=[-1])
            return op.Concat(real, imag, axis=-1)
    
        # register 'aten::complex'
        onnx_registry = torch.onnx.OnnxRegistry()
        onnx_registry.register_op(namespace="aten", op_name="complex", function=custom_aten_complex)
        print(f"aten::complex is supported by ONNX registry: \
            {onnx_registry.is_registered_op(namespace='aten', op_name='complex')}"
            )
        return onnx_registry
    
    
    if __name__ == "__main__":
        model = FFCResNetGenerator(
            input_nc=4,
            output_nc=3,
            ngf=64,
            n_downsampling=3,
            n_blocks=9,
            init_conv_kwargs={
                "ratio_gin": 0,
                "ratio_gout": 0,
            },
            downsample_conv_kwargs={
                "ratio_gin": 0,
                "ratio_gout": 0,
            },
            resnet_conv_kwargs={
                "ratio_gin": 0.75,
                "ratio_gout": 0.75,
            }
        )
        model.eval()
        
        input_data = torch.randn(1, 4, 512, 1024)
        args = (input_data,)
        export_options = torch.onnx.ExportOptions(
          onnx_registry=register_complex_for_torch_dynamo(),  ### add here
          dynamic_shapes=True
        )
        torch.onnx.dynamo_export(
            model,
            *args,
            export_options=export_options,
        ).save("dynamic_fft.onnx")
        print(f"Dynamic onnx exported to dynamic_fft.onnx")

  2. For dynamic shape inference, in function_libs/torch_lib/ops/fft.py of the onnxscript package in my virtual environment, I added a function _ifftn_onnx():

    _ifftn_onnx():
    @torch_op(
        "aten::_fft_c2r",
        trace_only=True,
        private=True,
        complex=True,
    )
    def _ifftn_onnx(
        self: TFloat, 
        dims: Sequence[int], 
        normalization: int, 
        last_dim_size: INT64
    ) -> TFloat:
        """Standard complex to real inverse FFT.
    
        Args:
            self: The input tensor.
            dims: The dimensions to apply FFT.
            normalization: The normalization mode.
            last_dim_size: The size of the last dimension.
    
        Returns:
            The transformed tensor.
        """
        # My model inputs are images with shape [batch, c, h, w],
        # so here the `self` tensor has shape [batch, c, h, w/2+1, 2].
        
        # The 0-th dimension in ONNX DFT-17 is the batch dimension. We need to add a new
        # dimension at the beginning to represent the batch dimension.
        transformed = op.Unsqueeze(self, axes=[0])
    
        # Add 1 to account for the batch dimension when counting axes from the left
        new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]
    
        for dim in new_dims[:-1]:
            transformed = op.DFT(transformed, axis=dim, inverse=True, onesided=False)
    
        # Torch computes the one-sided FFT on the last dimension only.
        ######################################################################################
        # There is an error in the DFT operator when `inverse` and `onesided` are both True:
        # Op (DFT) [ShapeInferenceError] is_onesided and inverse attributes cannot be enabled
        # at the same time
        ######################################################################################
        
        # **** custom irfft implementation ****
        # Build the conjugate part to undo the one-sided RFFT.
        # rfft produces w/2 + 1 bins, so reconstruct the missing conjugate-symmetric bins first.
        transformed_conj = transformed * op.Constant(value_floats=[1.0, -1.0])
    
        # flip the conjugate part
        transformed_conj = op.Transpose(transformed_conj, perm=[4, 0, 1, 2, 3, 5])
        sequence_len = op.CastLike(last_dim_size / 2 + 1, last_dim_size)
        sequence_lens = op.Expand(sequence_len, shape=[1])
        transformed_conj = op.ReverseSequence(
            transformed_conj, 
            batch_axis=1,
            time_axis=0, 
            sequence_lens=sequence_lens
        )
        transformed_conj = op.Transpose(transformed_conj, perm=[1, 2, 3, 4, 0, 5])
        
        # slice out the needed part
        # my input `self` tensor sizes are always even.
        starts = op.Constant(value_ints=[0, 0, 0, 0, 1, 0])
        transformed_conj = op.Slice(
            transformed_conj, starts=starts, ends=op.Shape(transformed)
        )
        
        # concatenate with original positive part
        transformed = op.Concat(transformed, transformed_conj, axis=new_dims[-1])
        transformed = op.DFT(
            transformed, last_dim_size, axis=new_dims[-1], inverse=True, onesided=False
        )
    
        # Remove the batch dimension                         
        transformed = op.Squeeze(transformed, axes=[0])
        
        ### Normalize the result. The code below raises an error, so I do the normalization inside my model instead.
        # The inverse DFT in ONNX already normalizes by 1/n (verified by testing), so we would need to multiply by n first when `forward` is False.
        # Reference https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn
        # Reference https://github.com/pytorch/pytorch/blob/d090c18fcaaba6e1b5cb474a89058cf6081c8275/torch/_refs/fft.py#L42
        # total_sample_count = last_dim_size
        # for dim_ in dims[:-1]:
        #     total_sample_count = total_sample_count * self_shape[dim_]#op.Constant(value_int=self_shape[dim_])
        # total_sample_count = op.CastLike(total_sample_count, transformed)
        
        # if normalization == 1:
        #     # "ortho" - normalize by 1/sqrt(n)
        #     transformed = op.Mul(transformed, op.Sqrt(total_sample_count))
        # elif normalization == 2:
        #     # "forward" - normalize by 1/n
        #     transformed = op.Mul(transformed, total_sample_count)
        return transformed
    References:
    - the function _fftn_onnx in function_libs/torch_lib/ops/fft.py of the onnxscript package
    - numpy


    Then find the function aten__fft_c2r() and replace the original implementation:

    @torch_op("aten::_fft_c2r", trace_only=True, complex=True)
    def aten__fft_c2r(
        self: TFloat,
        dim: Sequence[int],
        normalization: int,
        last_dim_size: INT64,  # pylint: disable=unused-argument
    ) -> TFloat:
        """_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor
    
        Complex to real inverse FFT.
        """
    
        self_rank = len(self.shape)
        # ONNX DFT input assumes the last dimension is the complex dimension.
        # Thus dim=-1 in PyTorch is dim=-2 in ONNX.
        dim = [(d - 1) + self_rank if d < 0 else d for d in dim]
        # transformed = _fftn_onnx(self, dim, normalization, inverse=True, onesided=True)  ### comment this line
        transformed = _ifftn_onnx(self, dim, normalization, last_dim_size=last_dim_size)   ### add this one
        # Take only the real part
        real_part = op.Slice(transformed, axes=[-1], starts=[0], ends=[1])
    
        return op.Squeeze(real_part, axes=[-1])
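
    For reference, a small eager-PyTorch sketch of the conjugate-completion trick that _ifftn_onnx relies on (my own sanity check for the 1-D, even-length case; not part of the exported graph):

    import torch

    x = torch.randn(8, 16)                        # real signal, even last dim
    half = torch.fft.rfft(x, dim=-1)              # one-sided spectrum: 16/2 + 1 = 9 bins
    # mirror the interior bins (1 .. n/2-1) with conjugation to rebuild bins n/2+1 .. n-1
    mirror = torch.conj(half[..., 1:-1].flip(-1))
    full = torch.cat([half, mirror], dim=-1)      # full 16-bin spectrum
    recon = torch.fft.ifft(full, dim=-1).real     # keep only the real part
    assert torch.allclose(recon, x, atol=1e-5)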

  3. Finally, to get correct results, I have to apply the normalization that _ifftn_onnx() leaves out (see step 2) inside my model:

    in FourierUnit of model.py:
    class FourierUnit(nn.Module):
        def __init__(self, in_channels, out_channels, groups=1, fft_norm='ortho', 
                     norm_layer=nn.BatchNorm2d, activation_layer=nn.ReLU):
            super(FourierUnit, self).__init__()
            self.groups = groups
            self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2, out_channels=out_channels * 2,
                                              kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False)
            self.bn = norm_layer(out_channels * 2)
            self.relu = activation_layer(True)
            self.fft_norm = fft_norm
    
        def forward(self, x):
            batch, channel, h, w = x.shape
    
            fft_dim = (-2, -1)
            ffted = torch.fft.rfftn(x, dim=fft_dim, norm="ortho")          ### set to `ortho`
            ffted = torch.stack((ffted.real, ffted.imag), dim=-1)          # (b, c, h, w/2+1, 2)
            ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()              # (b, c, 2, h, w/2+1)
            ffted = ffted.view((batch, 2 * channel, h, -1))                # (b, 2c, h, w/2+1)
    
            ffted = self.conv_layer(ffted)  # (b, 2c, h, w/2+1)
            ffted = self.relu(self.bn(ffted))
    
            ffted = ffted.view((batch, channel, 2, h, -1)).permute(0, 1, 3, 4, 2).contiguous()     # (b, c, h, w/2+1, 2)
            ffted = torch.complex(ffted[..., 0], ffted[..., 1])            # (b, c, h, w/2+1)
    
            ifft_shape_slice = x.shape[-2:]
            output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm)  # (b, c, h, w)
            ### I use "ortho" normalization through my model.
            output = output * torch.sqrt(torch.tensor(h * w, requires_grad=False))  # add this line
            return output
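
    The reason for the extra multiply, as far as I can tell: torch.fft.irfftn(..., norm="ortho") divides by sqrt(h*w), while the inverse DFT in _ifftn_onnx divides by h*w, so the exported graph comes out a factor of sqrt(h*w) too small. A quick check of that scaling relation in eager PyTorch:

    import torch

    x = torch.randn(1, 1, 8, 16)
    h, w = x.shape[-2:]
    f = torch.fft.rfftn(x, dim=(-2, -1), norm="ortho")
    ortho = torch.fft.irfftn(f, s=(h, w), dim=(-2, -1), norm="ortho")        # recovers x
    # norm="backward" divides the inverse by h*w, like the plain ONNX DFT inverse
    onnx_like = torch.fft.irfftn(f, s=(h, w), dim=(-2, -1), norm="backward")
    assert torch.allclose(onnx_like * (h * w) ** 0.5, ortho, atol=1e-5)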

After these 3 steps I can export my dynamic model successfully, but creating the inference session with ort.InferenceSession(onnx_file, providers=['CPUExecutionProvider']) is very slow. I don't know how to handle this. Here is a visualization of my onnx file:

visualization: rfftn-irfftn onnx


It contains a big subgraph, presumably produced by torch._dynamo. Could that be the reason onnxruntime loads the onnx file so slowly? Could anyone offer some help?
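
In case it helps to reproduce the load-time issue, this is roughly how I measure it (a sketch; the optimized_model_filepath / graph_optimization_level settings are just experiments on my side, not a confirmed fix):

import time
import onnxruntime as ort

so = ort.SessionOptions()
# Run onnxruntime's offline graph optimizations once and serialize the result,
# so later sessions can load the pre-optimized model instead.
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
so.optimized_model_filepath = "dynamic_fft.opt.onnx"

start = time.time()
session = ort.InferenceSession("dynamic_fft.onnx", sess_options=so,
                               providers=["CPUExecutionProvider"])
print(f"session created in {time.time() - start:.1f} s")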

@Lingstreasure Lingstreasure added the question Questions about ONNX label Mar 12, 2024
@Lingstreasure Lingstreasure changed the title Onnx file loads very slowly with onnxruntime, which is exported (dynamic_shapes=True) from a model with torch.rfftn and torch.irfftn modules. [ONNX] Onnx file loads very slowly with onnxruntime, which is exported (dynamic_shapes=True) from a model with torch.rfftn and torch.irfftn modules. Mar 12, 2024
@Lingstreasure Lingstreasure changed the title [ONNX] Onnx file loads very slowly with onnxruntime, which is exported (dynamic_shapes=True) from a model with torch.rfftn and torch.irfftn modules. [ONNX] Onnx file loads very slowly with onnxruntime, which is exported (dynamic_shapes=True) with torch.onnx_dynamo_export. Mar 14, 2024
@liqunfu liqunfu added the converters Issues related to ONNX converters label Mar 16, 2024
justinchuby (Contributor) commented:

The models produced by dynamo_export are known to run slowly (for now) because they are unoptimized. The API is in beta, and we intend to provide tools to optimize these models for onnxruntime soon.

gramalingam (Contributor) commented:

Is your concern about model-loading time or inference run-time? It may help to also report this in the onnxruntime repo and/or the pytorch exporter repo. As Justin says above, the transition to dynamo-exporter is in progress (and these concerns should be addressed soon).
