
How to avoid repeating kwargs in Python?

Tags:

python

dry

Suppose there is a method:

def train_model(self, out_dir='./out/',
                test_size=0.2, train_size=None,
                random_state=None, shuffle=True, stratify=None,
                epochs=DEFAULT_EPOCHS, batch_size=DEFAULT_BATCH_SIZE):
    ...
    self.model.train(test_size=test_size, train_size=train_size,
                     random_state=random_state, shuffle=shuffle,
                     stratify=stratify, epochs=epochs, batch_size=batch_size)

Inside this function, another method with the same signature is called, so I have to pass all the parameters manually. I don't want to use **kwargs in train_model because it is a public method that may be used by others, and I want to keep the typing information. Is there a way to keep the typing information while using kwargs in the outer function?

The same functionality in TypeScript can be achieved using the Parameters utility type. For example,

function sum(a: number, b: number) {
    return a + b;
}

type SumParamsType = Parameters<typeof sum>;

// Then you can use SumParamsType in other places.

A failed example of Python:

from typing import TypeVar
T = TypeVar('T')

def f1(a=1, b=2, c=3):
    return a+b+c

# Is there anything like T=Parameters(type(f1)) in Python?

def f2(z=0, **kwargs: T):
    return z+f1(**kwargs)

# T cannot capture the kwargs of f1 (of course it won't)

And this doesn't work either:

def f1(a=1, b=2, c=3):
    return a+b+c

def f2(z=0, **kwargs: f1.__annotations__['kwargs']):
    return z + f1(**kwargs)

# kwargs has the type Any
asked May 13 '26 by link89

2 Answers

The closest you can get is to use a TypedDict with Unpack (available on Python < 3.11 via typing_extensions):

from typing_extensions import Unpack, TypedDict, NotRequired


class Params(TypedDict):
    a: NotRequired[int]
    b: NotRequired[int]
    c: NotRequired[int]


def f1(**kwargs: Unpack[Params]):
    a = kwargs.pop('a', 1)
    b = kwargs.pop('b', 2)
    c = kwargs.pop('c', 3)
    return a + b + c


def f2(z=0, **kwargs: Unpack[Params]):
    return z + f1(**kwargs)

Note that your IDE may not support Unpack if it does not use mypy --enable-incomplete-feature=Unpack. VSCode supports it out of the box. PyCharm, probably not.
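Applied to the question's train_model, the same pattern could look like the sketch below. The Trainer class and the _train helper are hypothetical stand-ins for the poster's code; only the TypedDict/Unpack technique itself is the point:

```python
try:  # Unpack and NotRequired live in typing on Python 3.11+
    from typing import NotRequired, Optional, TypedDict, Unpack
except ImportError:  # fall back to typing_extensions on older versions
    from typing import Optional
    from typing_extensions import NotRequired, TypedDict, Unpack


class TrainParams(TypedDict):
    test_size: NotRequired[float]
    train_size: NotRequired[Optional[float]]
    random_state: NotRequired[Optional[int]]
    shuffle: NotRequired[bool]
    epochs: NotRequired[int]
    batch_size: NotRequired[int]


class Trainer:
    def train_model(self, out_dir: str = './out/',
                    **kwargs: Unpack[TrainParams]):
        # the shared keyword arguments are declared once (in TrainParams)
        # and forwarded as a group instead of being restated one by one
        return self._train(**kwargs)

    def _train(self, **kwargs: Unpack[TrainParams]):
        return kwargs
```

Type checkers that understand Unpack will then flag misspelled or wrongly typed keywords at the call site of train_model.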

If you control both function definitions, you may find it easier to change your methods to accept a dataclass encapsulating all the parameters and their defaults instead of taking each parameter individually.
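A sketch of that dataclass approach follows; the SplitConfig name and its fields are illustrative, not taken from the original code:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SplitConfig:
    # defaults live in one place instead of in every signature
    test_size: float = 0.2
    train_size: Optional[float] = None
    random_state: Optional[int] = None
    shuffle: bool = True


def train(cfg: SplitConfig) -> float:
    # the inner function receives the whole config object,
    # so nothing has to be repeated parameter by parameter
    return cfg.test_size


def train_model(cfg: Optional[SplitConfig] = None,
                out_dir: str = './out/') -> float:
    cfg = cfg if cfg is not None else SplitConfig()
    return train(cfg)
```

Callers still get full type hints on the dataclass constructor, and adding a new parameter means touching only SplitConfig.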

answered May 16 '26 by MEE


You could create a class containing the training arguments and pass it to the train method, as is done in the HuggingFace Transformers library.

Here is the code from their GitHub:

from dataclasses import asdict, dataclass, field, fields

#...

@dataclass
class TrainingArguments:
    framework = "pt"
    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )

    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
    do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
    evaluation_strategy: Union[IntervalStrategy, str] = field(
        default="no",
        metadata={"help": "The evaluation strategy to use."},
    )
    prediction_loss_only: bool = field(
        default=False,
        metadata={"help": "When performing evaluation and predictions, only returns the loss."},
    )

    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    # ...

It is a bit verbose but very clear, and it works with your IDE's type hinting.
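If the inner method still takes individual keyword arguments, dataclasses.asdict can expand the config object back into them. A minimal sketch of that idea (the Model class and the field names here are made up for illustration):

```python
from dataclasses import asdict, dataclass


@dataclass
class TrainingArguments:
    epochs: int = 10
    batch_size: int = 32
    shuffle: bool = True


class Model:
    def train(self, epochs: int, batch_size: int, shuffle: bool):
        return (epochs, batch_size, shuffle)


def train_model(args: TrainingArguments, model: Model):
    # asdict() turns the dataclass into a plain dict of keyword arguments,
    # so the outer function never restates the parameter list
    return model.train(**asdict(args))
```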

answered May 16 '26 by Caridorc

