Imagine I have a base class and two derived classes. I also have a factory function that returns an instance of one of the derived classes. The problem is that neither mypy nor IntelliJ can figure out which type the returned object is: they know it can be either one, but not which one exactly. Is there any way I can help mypy/IntelliJ figure this out WITHOUT putting a type hint next to the conn variable name?
```python
import abc
import enum
import typing


class BaseConnection(abc.ABC):
    @abc.abstractmethod
    def sql(self, query: str) -> typing.List[typing.Any]:
        ...


class PostgresConnection(BaseConnection):
    def sql(self, query: str) -> typing.List[typing.Any]:
        return "This is a postgres result".split()

    def only_postgres_things(self) -> None:
        pass


class MySQLConnection(BaseConnection):
    def sql(self, query: str) -> typing.List[typing.Any]:
        return "This is a mysql result".split()

    def only_mysql_things(self) -> None:
        pass


class ConnectionType(enum.Enum):
    POSTGRES = 1
    MYSQL = 2


def connect(conn_type: ConnectionType) -> typing.Union[PostgresConnection, MySQLConnection]:
    if conn_type is ConnectionType.POSTGRES:
        return PostgresConnection()
    if conn_type is ConnectionType.MYSQL:
        return MySQLConnection()
    # Avoid silently returning None for an unexpected value.
    raise ValueError(f"Unknown connection type: {conn_type}")


conn = connect(ConnectionType.POSTGRES)
conn.only_postgres_things()  # mypy flags this, since conn is typed as the whole Union
```
In IntelliJ's code completion for conn, both methods, only_postgres_things and only_mysql_things, are suggested, whereas I'd like IntelliJ/mypy to infer the concrete type from the argument I pass to the connect function.
Since the purpose of your ConnectionType class is apparently to make your API more readable and user-friendly rather than to use any specific features of Enum, you don't really have to make it an Enum class.
Instead, you can create a regular class whose user-friendly class variables hold the connection classes themselves. Then annotate the parameter of connect as the type of a type variable (type[Connection]) and its return value as the type variable itself, so the type checker infers the concrete class from the argument you pass. A type alias makes the parameter type more readable:
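```python
import typing

class ConnectionTypes:
    # Each class variable holds a connection class rather than an enum value.
    POSTGRES = PostgresConnection
    MYSQL = MySQLConnection

Connection = typing.TypeVar('Connection', PostgresConnection, MySQLConnection)
# or make it bound to the base class:
# Connection = typing.TypeVar('Connection', bound=BaseConnection)

# Generic alias (typing.TypeAlias needs Python 3.10+); pass it the type variable
# where it is used, because a bare generic alias means type[Any] to mypy.
ConnectionType: typing.TypeAlias = type[Connection]

def connect(type_: ConnectionType[Connection]) -> Connection:
    # Instantiate the class that was passed in, so the return type follows the argument.
    return type_()

conn = connect(ConnectionTypes.POSTGRES)  # inferred as PostgresConnection
conn.only_postgres_things()
```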
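For reference, here is a self-contained sketch of the bound-TypeVar variant mentioned in the commented-out line, runnable as a single file. The connection classes are trimmed copies of the ones in the question; the string return values of only_postgres_things and only_mysql_things are made-up placeholders so the result can be checked. It uses typing.Type so it also runs on Python versions older than 3.9:

```python
import abc
import typing


class BaseConnection(abc.ABC):
    @abc.abstractmethod
    def sql(self, query: str) -> typing.List[typing.Any]:
        ...


class PostgresConnection(BaseConnection):
    def sql(self, query: str) -> typing.List[typing.Any]:
        return "This is a postgres result".split()

    def only_postgres_things(self) -> str:
        return "postgres-only"  # placeholder value, for illustration only


class MySQLConnection(BaseConnection):
    def sql(self, query: str) -> typing.List[typing.Any]:
        return "This is a mysql result".split()

    def only_mysql_things(self) -> str:
        return "mysql-only"  # placeholder value, for illustration only


class ConnectionTypes:
    # Class objects, so the argument itself tells the type checker which subclass is wanted.
    POSTGRES = PostgresConnection
    MYSQL = MySQLConnection


# Bound rather than constrained: any subclass of BaseConnection is accepted.
Connection = typing.TypeVar("Connection", bound=BaseConnection)


def connect(type_: typing.Type[Connection]) -> Connection:
    # The return type is tied to the class argument, so no annotation is needed on the caller side.
    return type_()


conn = connect(ConnectionTypes.POSTGRES)  # a type checker should infer PostgresConnection here
print(conn.only_postgres_things())
```

The trade-off between the two TypeVar forms: the bound version accepts any BaseConnection subclass, including ones added later, while the constrained version restricts callers to exactly the listed classes.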
