Python generic type that implements protocol

Objects of types A, B, ... have a namespace attribute, and I have a function that filters a list of such objects, keeping only those whose namespace is in a given set of values:

from typing import TypeVar

T = TypeVar('T')


def filter(seq: list[T], namespace_values: set[str]) -> list[T]:
    # Returns a smaller list containing only the items from
    # `seq` whose `namespace` are in `namespace_values`
    ...

This works well, but the type checker also accepts a list of some type X that has no namespace attribute, without reporting any error.

So I created a protocol and changed the function to use it:


from typing import Protocol


class Namespaced(Protocol):
    namespace: str

def filter(seq: list[Namespaced], namespace_values: set[str]) -> list[Namespaced]:
    # Returns a smaller list containing only the items from
    # `seq` whose `namespace` are in `namespace_values`
    ...

Now I get a type-check error if I pass a list of X (which is what I wanted), but the return type is no longer generic:


list_of_a: list[A] = [a1, a2, a3]

output = filter(list_of_a, {'ns1', 'ns2'})

# output is list[Namespaced] instead of list[A]

How can I combine generics and the protocol so that my function returns a list of T and also checks that the items of seq implement the Namespaced protocol?

I tried the approach below, but T is lost:


def filter(seq: list[Namespaced[T]], namespace_values: set[str]) -> list[T]:
    # Returns a smaller list containing only the items from
    # `seq` whose `namespace` are in `namespace_values`
    ...

Cheers!

asked Dec 20 '21 08:12 by svecax


1 Answer

Use a bound type variable with the protocol as the bound. Consider the following module, test.py:

from typing import TypeVar, Protocol
from dataclasses import dataclass

class Namespaced(Protocol):
    namespace: str


T = TypeVar("T", bound="Namespaced")

@dataclass
class Foo:
    namespace: str

@dataclass
class Bar:
    namespace: str
    id: int

def frobnicate(namespaced: list[T]) -> list[T]:
    for x in namespaced:
        print(x.namespace)
    return namespaced

result1 = frobnicate([Foo('foo')])
result2 = frobnicate([Bar('bar', 1)])

reveal_type(result1)
reveal_type(result2)

Then mypy gives:

(py39) Juans-MacBook-Pro:~ juan$ mypy --strict test.py
test.py:27: note: Revealed type is "builtins.list[test.Foo*]"
test.py:28: note: Revealed type is "builtins.list[test.Bar*]"
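
Applied to the function from the question, the same idea might look like the sketch below. This is only an illustration, not part of the original answer: the name filter_by_namespace (chosen to avoid shadowing the built-in filter), the fields of A, and the sample data are all assumptions. It needs Python 3.9 or later for list[T].

```python
from dataclasses import dataclass
from typing import Protocol, TypeVar


class Namespaced(Protocol):
    namespace: str


# T may be any type that structurally satisfies Namespaced.
T = TypeVar("T", bound=Namespaced)


@dataclass
class A:
    # Hypothetical stand-in for the question's class A.
    namespace: str
    name: str


def filter_by_namespace(seq: list[T], namespace_values: set[str]) -> list[T]:
    # Keep only the items whose namespace is in the allowed set.
    return [item for item in seq if item.namespace in namespace_values]


list_of_a = [A("ns1", "first"), A("ns3", "second"), A("ns2", "third")]
output = filter_by_namespace(list_of_a, {"ns1", "ns2"})

print([item.name for item in output])  # ['first', 'third']
```

Because T is bound to the protocol, a type checker should infer output as a list of A rather than a list of Namespaced, and should reject a list whose items have no namespace attribute.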
answered Oct 23 '22 14:10 by juanpa.arrivillaga