Export PyTorch to ONNX #155

Open · saikrishna-pallerla opened this issue Dec 20, 2020 · 8 comments
Labels: enhancement (New feature or request)

Comments

@saikrishna-pallerla

I am trying to export a model trained on a custom dataset with the PyTorch framework to ONNX, then convert it to TensorRT and run it on a Jetson Nano. However, I am unable to convert the model to ONNX. Below is the code I am using:

[screenshot: export code]

But it throws the error below:

[screenshot: error traceback]

Can someone help me understand the issue and help fix it? @rwightman It would be great if you could provide guidance here.

@saikrishna-pallerla saikrishna-pallerla added the bug Something isn't working label Dec 20, 2020
@rwightman rwightman added the enhancement New feature or request label Dec 20, 2020
@rwightman
Owner

This isn't a bug, it's just functionality not implemented since it's non-trivial. See #89 and #32 ... I'll leave this one open so another issue isn't created.

I have no plans to tackle this in the near future; it would not be a learning experience for me. Others have asked, but nobody has offered any help or code. If someone gets it working, with an included demo export & inference script, I'd accept a PR.

@rwightman rwightman removed the bug Something isn't working label Dec 20, 2020
@Ekta246

Ekta246 commented Dec 23, 2020

> I am trying to export a model trained on a custom dataset with the PyTorch framework to ONNX, then convert it to TensorRT and run it on a Jetson Nano. However, I am unable to convert the model to ONNX. [...] Can someone help me understand the issue and help fix it?

Hi,
Maybe you could try opset = 11.
Also, why do you pass pretrained_backbone=False?
Moreover, have you already gotten past the Swish activation operator not being supported by ONNX?
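
For reference, switching opsets only changes one argument in the export call. A minimal sketch (the model and input here are placeholders, not the original poster's code):

import torch

# Placeholder model and input standing in for the original poster's setup.
net = torch.nn.Conv2d(3, 8, kernel_size=3)
dummy = torch.randn(1, 3, 512, 512)

# The suggestion above: re-run the export with opset_version=11.
torch.onnx.export(net, dummy, "model_opset11.onnx", opset_version=11)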

@saikrishna-pallerla
Author

> Hi,
> Maybe you could try opset = 11.
> Also, why do you pass pretrained_backbone=False?
> Moreover, have you already gotten past the Swish activation operator not being supported by ONNX?

I have tried opset 11 as well, but it didn't work. Do you think it is the Swish operator that is causing the error in my run, where it says the __is_ operator is not currently supported by ONNX? Or is it something else?

The error log doesn't show which method or logic failed to convert to ONNX. Is there a way to find that out, so that I can bypass it by rewriting the logic in a way acceptable to ONNX? @Ekta246 @rwightman

Btw, I have also tried changing config.act_type to both "silu" and "relu" in place of "swish". Neither of these helped.
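
For context, that activation swap happens on the effdet config before the model is built; a sketch of the experiment just described, assuming effdet's config API:

from effdet import get_efficientdet_config, EfficientDet

# Sketch of the act_type experiment above: "silu" and "relu" were tried in
# place of the default "swish"; neither resolved the export error.
config = get_efficientdet_config("efficientdet_d1")
config.act_type = "relu"
net = EfficientDet(config, pretrained_backbone=False)
net.eval()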

@Ekta246

Ekta246 commented Jan 13, 2021

Just a heads up!
Get rid of the memory-efficient / JIT autograd functions when converting to ONNX!
As a temporary workaround, you can change some functions in site-packages/timm/models/layers/activations_me.py: at line 54, instead of returning SwishJitAutoFn.apply(x), simply return x * torch.sigmoid(x) (again, only while converting to ONNX).

  1. You might also add the exportable=True flag/argument to the create_model function for the backbone (EfficientNet); see the sketch after this list.
  2. You might also need some modifications to the F.interpolate calls anyway.
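
A minimal sketch of the activations_me.py workaround and the exportable flag, assuming a timm version from around early 2021 (paths and flags may differ in newer releases):

import torch
import timm

# 1. The plain (non memory-efficient) Swish that ONNX can trace, i.e. what
#    activations_me.py line 54 would return instead of SwishJitAutoFn.apply(x):
def swish(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(x)

# 2. Ask timm for export-friendly layers when building the backbone;
#    create_model accepts exportable=True and applies it to its layer config.
backbone = timm.create_model("efficientnet_b1", pretrained=False, exportable=True)
backbone.eval()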

@ghavart

ghavart commented Aug 5, 2021

Hi @saikrishna-pallerla, any success since your last post? I've run into the same problem. Another user said he's seen success with non-TF backbones, but so far that hasn't helped me.

@fujikosu

I hit the exact issue raised here with the ONNX conversion, as shown below. But I could partially make it work, so I'm leaving my learnings here in case they are useful for someone else.

RuntimeError: Exporting the operator __is_ to ONNX opset version 13 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.

What I was trying to export was DetBenchPredict with the efficientdet_d1 architecture. Below is a snippet of my conversion code. This link mentioned that we need to pass a dictionary when a target model's forward method has keyword arguments, so I added {"img_info": None} as an input, but the error message didn't change.

import torch
from effdet import get_efficientdet_config, EfficientDet, DetBenchPredict

config = get_efficientdet_config("efficientdet_d1")
net = EfficientDet(config, pretrained_backbone=True)
net = DetBenchPredict(net)
net.eval()
torch.onnx.export(net.cuda(),                                                 # model being run
                  (torch.randn(1, 3, 512, 512).cuda(), {"img_info": None}),   # model input (or a tuple for multiple inputs)
                  "effdet_all.onnx",                                          # where to save the model (file or file-like object)
                  input_names=["input"],                                      # the model's input names
                  output_names=["output"],                                    # the model's output names
                  opset_version=13, verbose=True)

I used torch==1.9.1

Next, I narrowed down the conversion target to figure out where the error was coming from. I extracted the EfficientDet part of the model, applied the conversion, and it succeeded.

model = net.model  # the bare EfficientDet, without the DetBenchPredict wrapper
model.eval()
torch.onnx.export(model.cuda(),                          # model being run
                  torch.randn(1, 3, 512, 512).cuda(),    # model input
                  "effdet_modelpart.onnx",               # where to save the model
                  input_names=["input"],                 # the model's input names
                  output_names=["output"],               # the model's output names
                  opset_version=13, verbose=True)

I saw other people reporting that the type of activation layer affects ONNX exportability, but I didn't find a way to specify exportable=True for the non-backbone part of EfficientDet, so I wondered whether that was causing an error; it didn't in my case. I saw that get_act_fn in create_act.py in timm converts SiLU to Swish to keep a model ONNX-convertible. As of PyTorch 1.9.1 (what I used), export of SiLU seems to be supported:

    if is_exportable() and name in ("silu", "swish"):
        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
        return swish

From this result, it's clear that the error comes from the post-processing stage in the forward method of DetBenchPredict. Given that the error message mentioned __is_, I suspected the "is" keyword was somehow causing the conversion issue, which is odd since, per the official docs, "is" is supported in ONNX. But I just commented out all of the if/else lines that use "is". These span the forward method of DetBenchPredict and generate_detections, and all of them are related to img_info. With those commented out, ONNX conversion succeeded, though I assume without box scaling etc. in the graph (I haven't tested).
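
To make that concrete, the branches in question look roughly like this (paraphrased from the description above, not the exact effdet source); the "img_info is None" test is what the exporter rejects as __is_:

# Paraphrased sketch of the img_info handling in DetBenchPredict.forward /
# generate_detections; an assumption, not the exact library code.
def resolve_img_info(img_info):
    if img_info is None:              # this test exports as the unsupported __is_ op
        return None, None
    return img_info["img_scale"], img_info["img_size"]

# The workaround described above amounts to removing that branch while
# exporting, i.e. hard-coding the no-img_info case:
def resolve_img_info_exportable(img_info=None):
    return None, None                 # box scaling then happens outside ONNX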

So, current potential workarounds would be:

  • Convert only the EfficientDet part of the model to ONNX and keep the post-processing outside ONNX (a sketch of this option follows below), or
  • Move the box-scaling operations out of DetBenchPredict, convert it to ONNX, and perform box scaling as a post-processing step after ONNX inference.
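
A rough sketch of the first option, using onnxruntime on the effdet_modelpart.onnx file produced above (the exact output layout is not verified here):

import numpy as np
import onnxruntime as ort

# Run the EfficientDet-only graph exported above with onnxruntime.
session = ort.InferenceSession("effdet_modelpart.onnx")
dummy = np.random.randn(1, 3, 512, 512).astype(np.float32)  # placeholder image batch
outputs = session.run(None, {"input": dummy})

# Anchor decoding, NMS, and scaling boxes back to the original image size
# (the img_info-dependent logic) then run outside ONNX, e.g. by reusing
# effdet's generate_detections on these raw head outputs.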

I"m still new to both ONNX and this codebase so I haven"t given enough thoughts onto which is easier though. But I assume EfficientDet part is most computation heavy part of the inference so being able to convert that part to ONNX at least was great to know for me.

@achbogga

achbogga commented Jun 1, 2023

Has anyone been able to successfully convert "tf_efficientnetv2_s" from timm to ONNX?
Please let me know.

@artemisart

I have (hopefully) working ONNX exports, detailed in #302.
