How to get dynamic batch size in Onnx model from Pytorch?

Tags:

pytorch

onnx

I have a currently working PyTorch to Onnx conversion process that I would like to enable a dynamic batch size for.

I am setting the dynamic axes like this:

    model.to_onnx(
        onnx_filepath,
        input_sample=(x, None),
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=['image'],
        output_names=['boxes', 'scores', 'labels'],
        dynamic_axes={
            "image": [0],
            "boxes": [0],
            "scores": [0],
            "labels": [0],
        }
    )

However, the onnx model that is produced only shows dynamic batch size for the input:

(Netron screenshot of the exported model: only the 'image' input shows a dynamic batch dimension)

How do I make the output batch size dynamic? Do I need to modify the forward function of the model in some way?
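For reference, one way to confirm which axes actually came out dynamic (besides Netron) is to read the graph with the onnx package. A minimal sketch, with "model.onnx" standing in for onnx_filepath above:

    import onnx

    model = onnx.load("model.onnx")  # placeholder for onnx_filepath

    def print_dims(value_infos, kind):
        for vi in value_infos:
            # dynamic axes appear as named dim_param entries, static ones as integers
            dims = [d.dim_param or d.dim_value for d in vi.type.tensor_type.shape.dim]
            print(kind, vi.name, dims)

    print_dims(model.graph.input, "input ")
    print_dims(model.graph.output, "output")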

asked Oct 11 '25 by brad


1 Answer

This is just a guess, but are you by any chance processing each input image of the batch (or post-processing the detections) separately inside a for-loop? If so, the behaviour is likely due to how torch traces the model during ONNX export, and you will need to modify your forward pass or, alternatively, use torch.jit.script.

Where the forward pass can go wrong

Check your model for anything that turns a tensor dimension into a plain Python integer during export. Setting dynamic_axes tells the exporter to keep the corresponding shapes symbolic, but that is overridden wherever a shape has already been baked in as a constant.

# WRONG - WILL EXPORT WITH A STATIC BATCH SIZE
def forward(self, batch):
    bs, c, h, w = batch.shape
    # bs is baked in as a constant integer during export
    for i in range(bs):
        do_something()

# WRONG - WILL EXPORT WITH A STATIC BATCH SIZE
def forward(self, batch):
    # iterating over a tensor is not supported for dynamic batch sizes;
    # the ONNX model will be unrolled for the batch size seen during export
    for i in batch:
        do_something()
Potential fixes

Use tensor.size(dim) instead of unpacking tensor.shape

# CORRECT - WILL EXPORT WITH DYNAMIC AXES
def forward(self, batch):
    # size() is a function call rather than an attribute access,
    # so the value stays dynamic during export
    bs = batch.size(0)
    for i in range(bs):
        do_something()

Script parts of the model to preserve control flow and variable input sizes

# CORRECT - WILL EXPORT WITH DYNAMIC AXES
# Script parts of the forward pass, e.g. single functions
@torch.jit.script_if_tracing
def do_something(batch):
    for i in batch:
        do_something_else()

def forward(self, batch):
    # the function is scripted, so dynamic shapes are preserved
    do_something(batch)
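Note that torch.jit.script_if_tracing only compiles the decorated function when it is first called during tracing, so eager execution and normal training are unaffected; scripting just the loop lets the exporter keep it as control flow instead of unrolling it for the batch size seen during export.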

Export the whole module as a ScriptModule, preserving all control flow and dynamic input sizes

# CORRECT - WILL EXPORT WITH DYNAMIC AXES
script_module = torch.jit.script(model)
torch.onnx.export(
    script_module,
    ...
)
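To verify that the export actually worked, a quick runtime check is to run the model at two different batch sizes. A sketch assuming onnxruntime is installed, using the input/output names from the question (the 3x640x640 image shape and the file path are just placeholders):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx")  # placeholder path

for bs in (1, 4):
    x = np.random.rand(bs, 3, 640, 640).astype(np.float32)
    boxes, scores, labels = sess.run(None, {"image": x})
    # with dynamic axes on the outputs, the leading dimension should follow bs
    print(bs, boxes.shape, scores.shape, labels.shape)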
answered Oct 14 '25 by simeonovich


