I have a working PyTorch-to-ONNX conversion process, and I would like to enable a dynamic batch size for it.
I am setting the dynamic axes like this:
model.to_onnx(
    onnx_filepath,
    input_sample=(x, None),
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=['image'],
    output_names=['boxes', 'scores', 'labels'],
    dynamic_axes={
        "image": [0],
        "boxes": [0],
        "scores": [0],
        "labels": [0],
    }
)
However, in the ONNX model that is produced, only the input has a dynamic batch size; the outputs do not.
How do I make the output batch size dynamic? Do I need to modify the forward function of the model in some way?
This is just a guess, but are you by any chance processing each input image of the batch (or post-processing the detections) separately inside a for-loop? If so, the behaviour is likely caused by how torch exports to ONNX: the model is traced, so Python control flow is recorded only for the batch size of the export sample. You will then need to modify your forward pass, or alternatively use torch.jit.script.
Check your model for any place where a tensor dimension is read as a plain Python integer during export. Setting dynamic_axes tells the exporter to use variable shapes for the corresponding tensors, but those are overridden wherever a dimension was baked in as a constant.
# WRONG - WILL EXPORT WITH STATIC BATCH SIZE
def forward(self, batch):
    bs, c, h, w = batch.shape
    # bs is saved as a constant integer during export
    for i in range(bs):
        do_something()
# WRONG - WILL EXPORT WITH STATIC BATCH SIZE
def forward(self, batch):
    # iterating over tensors is not supported for dynamic batch sizes;
    # the ONNX model will iterate as many times as the batch size
    # used during export
    for i in batch:
        do_something()
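The first pattern can be reproduced in isolation with `torch.jit.trace`, the tracing mechanism the TorchScript-based ONNX exporter builds on. This is a minimal sketch; the `RowSums` module is hypothetical and not from the question. The loop is unrolled for the example batch of 2, so the traced graph still returns 2 values when it is fed 5 samples.

```python
import warnings

import torch


class RowSums(torch.nn.Module):
    """Hypothetical model that post-processes each sample in a Python loop."""

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        bs = batch.shape[0]
        # range() needs a Python int, so the tracer freezes bs here
        return torch.stack([batch[i].sum() for i in range(bs)])


with warnings.catch_warnings():
    # silence any TracerWarning about converting a tensor to a Python index
    warnings.simplefilter("ignore")
    traced = torch.jit.trace(RowSums(), torch.randn(2, 3))

out = traced(torch.randn(5, 3))
print(out.shape)  # torch.Size([2]): only the first two samples are processed
```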
Use tensor.size() instead of tensor.shape
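# CORRECT - WILL EXPORT WITH DYNAMIC AXES
def forward(self, batch):
    # size(0) is recorded as a shape operation in the traced graph,
    # so bs stays dynamic as long as it only feeds tensor operations
    bs = batch.size(0)
    return batch.reshape(bs, -1)
    # Note: range(bs) would still turn bs into a Python int and unroll
    # the loop to the export-time batch size; script such loops instead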
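A minimal sketch of a pattern that keeps the batch axis dynamic under tracing (the `Flatten` module is hypothetical): `batch.size(0)` is only consumed by `reshape`, so the traced graph accepts batch sizes other than the one used for tracing.

```python
import torch


class Flatten(torch.nn.Module):
    """Hypothetical module: flattens every sample, keeping the batch axis."""

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        bs = batch.size(0)
        # bs only feeds a tensor op, so it stays a traced value
        return batch.reshape(bs, -1)


traced = torch.jit.trace(Flatten(), torch.randn(2, 3, 4))
print(traced(torch.randn(5, 3, 4)).shape)  # torch.Size([5, 12])
```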
Script parts of the model to preserve control flow and varying input sizes
# CORRECT - WILL EXPORT WITH DYNAMIC AXES
# Script parts of the forward pass, e.g. single functions
@torch.jit._script_if_tracing
def do_something(batch):
    for i in batch:
        do_something_else()

def forward(self, batch):
    # function will be scripted, dynamic shapes preserved
    do_something(batch)
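Here is a runnable sketch of that pattern. `per_sample_sums` is a hypothetical stand-in for per-image post-processing, and `torch.jit._script_if_tracing` is a private PyTorch helper (note the leading underscore), so its behaviour may change between releases. When the decorated function is called during tracing, it is compiled with TorchScript, so its loop survives as real control flow instead of being unrolled.

```python
from typing import List

import torch


@torch.jit._script_if_tracing
def per_sample_sums(batch: torch.Tensor) -> torch.Tensor:
    # Scripted when called during tracing, so this loop is not unrolled
    outputs: List[torch.Tensor] = []
    for i in range(batch.size(0)):
        outputs.append(batch[i].sum())
    return torch.stack(outputs)


class Model(torch.nn.Module):
    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        return per_sample_sums(batch)


traced = torch.jit.trace(Model(), torch.randn(2, 3))
print(traced(torch.randn(5, 3)).shape)  # torch.Size([5])
```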
Export the whole module as a ScriptModule, preserving all control flow and input sizes
# CORRECT - WILL EXPORT WITH DYNAMIC AXES
script_module = torch.jit.script(model)
torch.onnx.export(
    script_module,
    ...
)
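As a sketch of why scripting the whole module helps (the `LoopModel` module is hypothetical): `torch.jit.script` compiles the Python loop into a `prim::Loop` node instead of unrolling it, so the resulting ScriptModule already handles any batch size before it is passed to `torch.onnx.export`.

```python
import torch


class LoopModel(torch.nn.Module):
    """Hypothetical model with a per-sample Python loop."""

    def forward(self, batch: torch.Tensor) -> torch.Tensor:
        outputs = []  # TorchScript infers List[Tensor] for an empty list
        for i in range(batch.size(0)):
            outputs.append(batch[i].sum())
        return torch.stack(outputs)


script_module = torch.jit.script(LoopModel())
print(script_module(torch.randn(2, 3)).shape)  # torch.Size([2])
print(script_module(torch.randn(5, 3)).shape)  # torch.Size([5])
# The loop is preserved as control flow in the TorchScript graph
print("prim::Loop" in str(script_module.graph))  # True
```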