
Handling conditional logic + sentinel value with mypy

I have a function that looks roughly like this:

import datetime
from typing import Union

class Sentinel(object): pass
sentinel = Sentinel()

def func(
    dt: datetime.datetime,
    as_tz: Union[datetime.tzinfo, None, Sentinel] = sentinel,
) -> str:

    if as_tz is not sentinel:
        # Never reached if as_tz has wrong type (Sentinel)
        dt = dt.astimezone(as_tz)
    # ...
    # do other meaningful stuff
    # ...
    return "foo"

The sentinel value is used here because None is already a valid argument to .astimezone(), so the purpose is to correctly identify cases where the user doesn't want to call .astimezone() at all.

However, mypy complains about this pattern with:

error: Argument 1 to "astimezone" of "datetime" has incompatible type "Union[tzinfo, None, Sentinel]"; expected "Optional[tzinfo]"

It seems that's because the datetime stub (rightfully so) uses:

def astimezone(self, tz: Optional[_tzinfo] = ...) -> datetime: ...

But, is there a way to let mypy know that the sentinel value will never be passed to .astimezone() because of the if check? Or does this just need a # type: ignore with there being no cleaner way?


Another example:

from typing import Optional
import requests


def func(session: Optional[requests.Session] = None):
    new_session_made = session is None
    if new_session_made:
        session = requests.Session()
    try:
        session.request("GET", "https://a.b.c.d.com/foo")
        # ...
    finally:
        if new_session_made:
            session.close()

This second example, like the first, is "runtime safe" (for lack of a better term): the calls None.request() and None.close(), which would raise AttributeError, are never reached. However, mypy still complains that:

mypytest.py:9: error: Item "None" of "Optional[Session]" has no attribute "request"
mypytest.py:13: error: Item "None" of "Optional[Session]" has no attribute "close"

Should I be doing something differently here?

Brad Solomon asked Sep 16 '19 15:09


2 Answers

In my experience, the best solution is to use enum.Enum.

Requirements

A good sentinel pattern has 3 properties:

  1. possesses an unambiguous type/value that cannot be mistaken for other values, e.g. object()
  2. can be referenced using a descriptive constant
  3. can be tested concisely, using is and is not

Solution

mypy treats enum.Enum specially: an is / is not comparison against an enum member narrows the type. It is the only solution I've found that meets all three of these requirements and also validates properly under mypy.

import datetime
import enum
from typing import Union

class Sentinel(enum.Enum):
    SKIP_TZ = object()

def func(
    dt: datetime.datetime,
    as_tz: Union[datetime.tzinfo, None, Sentinel] = Sentinel.SKIP_TZ,
) -> str:

    if as_tz is not Sentinel.SKIP_TZ:
        dt = dt.astimezone(as_tz)
    # ...
    # do other meaningful stuff
    # ...
    return "foo"
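
To make the three calling modes concrete, here is a minimal sketch. The name convert is hypothetical: it is a variant of func that returns the datetime instead of a string so the result can be inspected. The fixed offset of two hours is only an illustration.

```python
import datetime
import enum
from typing import Union


class Sentinel(enum.Enum):
    SKIP_TZ = object()


def convert(
    dt: datetime.datetime,
    as_tz: Union[datetime.tzinfo, None, Sentinel] = Sentinel.SKIP_TZ,
) -> datetime.datetime:
    # Hypothetical variant of func that returns the datetime for inspection.
    if as_tz is not Sentinel.SKIP_TZ:
        dt = dt.astimezone(as_tz)
    return dt


utc_noon = datetime.datetime(2020, 1, 1, 12, 0, tzinfo=datetime.timezone.utc)
plus_two = datetime.timezone(datetime.timedelta(hours=2))

untouched = convert(utc_noon)                # sentinel default: no conversion
shifted = convert(utc_noon, as_tz=plus_two)  # explicit tzinfo: converted
local = convert(utc_noon, as_tz=None)        # None: system local time zone
```

Omitting as_tz leaves the datetime alone, while passing None still reaches .astimezone() and converts to the local time zone, which is the distinction the sentinel exists to preserve.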

Variations

There are a few other interesting properties to this solution.

Reusable Sentinel object

sentinel.py

import enum
class Sentinel(enum.Enum):
    sentinel = object()

main.py

import datetime
from sentinel import Sentinel
from typing import Union

SKIP_TZ = Sentinel.sentinel

def func(
    dt: datetime.datetime,
    as_tz: Union[datetime.tzinfo, None, Sentinel] = SKIP_TZ,
) -> str:

    if as_tz is not SKIP_TZ:
        dt = dt.astimezone(as_tz)
    # ...
    # do other meaningful stuff
    # ...
    return "foo"

Note that because Sentinel.sentinel always provides the same object instance, two reusable sentinels should never be used in the same context.
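
A small sketch of why that matters. The second alias NO_LOCALE and the function describe are hypothetical names, introduced only to show the collision:

```python
import enum


class Sentinel(enum.Enum):
    sentinel = object()


# Two hypothetical module-level aliases for the shared member.
SKIP_TZ = Sentinel.sentinel
NO_LOCALE = Sentinel.sentinel

# Both names refer to the very same object.
same_object = SKIP_TZ is NO_LOCALE


def describe(tz: object = SKIP_TZ, locale: object = NO_LOCALE) -> str:
    # A caller passing NO_LOCALE as tz is indistinguishable from SKIP_TZ.
    return "tz skipped" if tz is SKIP_TZ else "tz given"


collision = describe(tz=NO_LOCALE)
```

If two distinct sentinels are needed in one signature, giving the enum one member per meaning avoids the collision.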

Restrict Sentinel value using Literal

Replacing Sentinel with Literal[Sentinel.SKIP_TZ] adds further clarity to your function signature, though it is admittedly superfluous here, since the enum has only one value.

import datetime
import enum
from typing import Union
from typing_extensions import Literal

class Sentinel(enum.Enum):
    SKIP_TZ = object()

def func(
    dt: datetime.datetime,
    as_tz: Union[datetime.tzinfo, None, Literal[Sentinel.SKIP_TZ]] = Sentinel.SKIP_TZ,
) -> str:

    if as_tz is not Sentinel.SKIP_TZ:
        dt = dt.astimezone(as_tz)
    # ...
    # do other meaningful stuff
    # ...
    return "foo"

func(datetime.datetime.now(), as_tz=Sentinel.SKIP_TZ)
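
The Literal form stops being superfluous once the enum has more than one member. The following is a sketch under that assumption: USE_UTC is a hypothetical second member, and typing.Literal (Python 3.8+) is used in place of typing_extensions to keep the example dependency-free.

```python
import datetime
import enum
from typing import Literal, Union  # typing.Literal requires Python 3.8+


class Sentinel(enum.Enum):
    SKIP_TZ = object()
    USE_UTC = object()  # hypothetical second member, for illustration only


def func(
    dt: datetime.datetime,
    as_tz: Union[datetime.tzinfo, None, Literal[Sentinel.SKIP_TZ]] = Sentinel.SKIP_TZ,
) -> datetime.datetime:
    # Only SKIP_TZ is admitted by the annotation, so ruling it out
    # leaves Optional[tzinfo] for astimezone().
    if as_tz is not Sentinel.SKIP_TZ:
        dt = dt.astimezone(as_tz)
    return dt


now = datetime.datetime.now()
result = func(now)
```

With this signature a type checker should reject func(now, as_tz=Sentinel.USE_UTC), whereas a plain Sentinel annotation would accept it. Nothing enforces this at runtime.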

Solutions that did not meet my requirements

Custom sentinel class

import datetime
from typing import Union

class SentinelType:
    pass

SKIP_TZ = SentinelType()


def func(
    dt: datetime.datetime,
    as_tz: Union[datetime.tzinfo, None, SentinelType] = SKIP_TZ,
) -> str:

    if not isinstance(as_tz, SentinelType):
        dt = dt.astimezone(as_tz)
    # ...
    # do other meaningful stuff
    # ...
    return "foo"

While this works, using isinstance(as_tz, SentinelType) fails requirement 3 ("use is"), and thus also requirement 2 ("use named constant"). For clarity, I want to be able to use if as_tz is not SKIP_TZ.

object Literal

Literal does not work with arbitrary values (though it does work with enum members; see above).

import datetime
from typing import Union
from typing_extensions import Literal

SKIP_TZ = object()

def func(
    dt: datetime.datetime,
    as_tz: Union[datetime.tzinfo, None, Literal[SKIP_TZ]] = SKIP_TZ,
) -> str:

    if as_tz is not SKIP_TZ:
        dt = dt.astimezone(as_tz)
    # ...
    # do other meaningful stuff
    # ...
    return "foo"

Produces the following mypy errors:

error: Parameter 1 of Literal[...] is invalid
error: Variable "sentinel.SKIP_TZ" is not valid as a type

String Literal

In this attempt, I used a string literal instead of an object:

import datetime
from typing import Union
from typing_extensions import Literal


def func(
    dt: datetime.datetime,
    as_tz: Union[datetime.tzinfo, None, Literal['SKIP_TZ']] = 'SKIP_TZ',
) -> str:

    if as_tz is not 'SKIP_TZ':
        dt = dt.astimezone(as_tz)
    # ...
    # do other meaningful stuff
    # ...
    return "foo"

func(datetime.datetime.now(), as_tz='SKIP_TZ')

Even if this worked, it would be pretty weak on requirement 1.

But it doesn't pass in mypy. It produces the error:

error: Argument 1 to "astimezone" of "datetime" has incompatible type "Union[tzinfo, None, Literal['SKIP_TZ']]"; expected "Optional[tzinfo]"
chadrik answered Oct 12 '22 11:10


You could use an explicit cast:

    from typing import Optional, cast
    ... 
    if as_tz is not sentinel:
        # Never reached if as_tz has wrong type (Sentinel)
        # Optional[...] because None is still a valid argument here
        as_tz = cast(Optional[datetime.tzinfo], as_tz)
        dt = dt.astimezone(as_tz)

and

    new_session_made = session is None
    session = cast(requests.Session, session)
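
Keep in mind what cast does at runtime: typing.cast returns its argument unchanged and checks nothing, so a wrong cast only surfaces later. A minimal sketch with illustrative values:

```python
from typing import Optional, cast

value: Optional[str] = None

# cast() returns its argument unchanged; nothing is checked at runtime.
narrowed = cast(str, value)

# The type checker now treats narrowed as str, yet it is still None.
still_none = narrowed is None
```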

Alternatively, you could use an assert (this is an actual runtime check, whereas cast is a no-op at runtime):

        # Rules out only the sentinel; None remains a valid argument
        assert not isinstance(as_tz, Sentinel)
        dt = dt.astimezone(as_tz)

and

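    new_session_made = session is None
    if new_session_made:
        session = requests.Session()
    # Must come after the reassignment, or it fails when session is None
    assert session is not None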
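
Put together, the assert variant of the second example might look like the sketch below. FakeSession is a hypothetical stand-in for requests.Session so that the code runs without network access or third-party packages. Returning the session is also an addition, made only so the outcome can be inspected.

```python
from typing import List, Optional


class FakeSession:
    """Hypothetical stand-in for requests.Session so the sketch runs offline."""

    def __init__(self) -> None:
        self.calls: List[str] = []
        self.closed = False

    def request(self, method: str, url: str) -> None:
        self.calls.append(f"{method} {url}")

    def close(self) -> None:
        self.closed = True


def func(session: Optional[FakeSession] = None) -> FakeSession:
    new_session_made = session is None
    if new_session_made:
        session = FakeSession()
    # Runtime check that also narrows Optional[FakeSession] to FakeSession.
    assert session is not None
    try:
        session.request("GET", "https://a.b.c.d.com/foo")
    finally:
        # Only close sessions this function created itself.
        if new_session_made:
            session.close()
    return session


created = func()
supplied = FakeSession()
func(supplied)
```

The assert is needed because the type checker does not connect the boolean new_session_made to the type of session. A direct if session is None: check would narrow without it, but the flag is still required for the finally block.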
Andrew Jaffe answered Oct 12 '22 11:10