
Add personalized methods and attributes to CategoricalDtype

Tags: python, pandas

Is there a way to add personalized methods and attributes to Pandas CategoricalDtype? Should I use a class inheritance or something like ExtensionDtype?

For example:

import pandas as pd
from pandas.api.types import CategoricalDtype

vehicles = ["Plane", "Rocket", "Car", "Truck"]
vehicle_dtype = CategoricalDtype(categories=vehicles)
s = pd.Series(["Plane", "Plane", "Car"])
s = s.astype(vehicle_dtype)

Is there a solution to add methods and attributes to vehicle_dtype in order to do things like this?

s.cat.is_flying
[True, True, False]

Thank you for your help.

Asked by Guillaume on Feb 28 '19 at 10:02


1 Answer

s.cat is a pandas.core.arrays.categorical.CategoricalAccessor. For s.cat.is_flying to work, you would need to tell the series to use your own subclass of that accessor instead of the default one. I don't know how to do that, though somebody might.

You could monkeypatch your accessor on afterwards, but you'd have to do it every time you create a new series, which is brittle and hard to maintain.

What you can do instead is register a separate custom accessor rather than going through .cat. These aren't difficult to define; see the pandas documentation on registering custom accessors. Below is an example that should work for your use case:

import pandas as pd

VehicleDtype = pd.api.types.CategoricalDtype(["Plane", "Rocket", "Car", "Truck"])

@pd.api.extensions.register_series_accessor("vehicle")
class VehicleAccessor:
    def __init__(self, series):
        self._validate(series)
        self._series = series

    @staticmethod
    def _validate(series):
        if not isinstance(series.dtype, pd.api.types.CategoricalDtype) or series.dtype != VehicleDtype:
            raise TypeError("Must be VehicleDtype.")

    @property
    def is_flying(self):
        return (self._series == "Plane") | (self._series == "Rocket")

s = pd.Series(["Plane", "Plane", "Car"])
s = s.astype(VehicleDtype)

s
# 0    Plane
# 1    Plane
# 2      Car
# dtype: category
# Categories (4, object): [Plane, Rocket, Car, Truck]

s.vehicle.is_flying
# 0     True
# 1     True
# 2    False
# dtype: bool

Series of any other dtype are unaffected. They raise an error only when the .vehicle accessor is actually used:

s2 = pd.Series(list("abcde"))  # works fine
s2.vehicle # TypeError: Must be VehicleDtype.

Note, though, that calling dir(s2) will raise this same error.
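One possible way around the dir() problem, sketched here rather than taken from the answer above, is to raise AttributeError instead of TypeError during validation. pandas appears to check accessors with hasattr() when building dir(), and hasattr() only swallows AttributeError. The accessor name "vehicle_safe" and the class name SafeVehicleAccessor are hypothetical, chosen so this sketch does not clash with the "vehicle" accessor registered earlier.

```python
import pandas as pd

VehicleDtype = pd.api.types.CategoricalDtype(["Plane", "Rocket", "Car", "Truck"])


# "vehicle_safe" is a hypothetical accessor name used only for this sketch.
@pd.api.extensions.register_series_accessor("vehicle_safe")
class SafeVehicleAccessor:
    def __init__(self, series):
        self._validate(series)
        self._series = series

    @staticmethod
    def _validate(series):
        is_categorical = isinstance(series.dtype, pd.api.types.CategoricalDtype)
        if not is_categorical or series.dtype != VehicleDtype:
            # AttributeError (not TypeError) so that hasattr() returns False
            # instead of letting the exception escape.
            raise AttributeError("Can only use .vehicle_safe with VehicleDtype.")

    @property
    def is_flying(self):
        # Boolean Series: True where the vehicle is a Plane or a Rocket.
        return self._series.isin(["Plane", "Rocket"])


s = pd.Series(["Plane", "Plane", "Car"]).astype(VehicleDtype)
s2 = pd.Series(list("abcde"))

flying = s.vehicle_safe.is_flying        # works on the right dtype
available = hasattr(s2, "vehicle_safe")  # False on the wrong dtype
names = dir(s2)                          # should no longer raise
```

The trade-off is that s2.vehicle_safe may then fail with the generic "object has no attribute" message rather than the custom one, because attribute lookup falls back to the Series' own handler once the accessor raises AttributeError.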

There is a similar function, pd.api.extensions.register_dataframe_accessor, to register accessors for dataframes.
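As a rough illustration of the dataframe variant, here is a minimal sketch using pd.api.extensions.register_dataframe_accessor. The accessor name "fleet", the class FleetAccessor, and the "vehicle" and "seats" columns are all hypothetical and not part of the original question.

```python
import pandas as pd

FLYING = ["Plane", "Rocket"]


# "fleet" is a hypothetical accessor name; the "vehicle" column is an assumption.
@pd.api.extensions.register_dataframe_accessor("fleet")
class FleetAccessor:
    def __init__(self, df):
        if "vehicle" not in df.columns:
            raise AttributeError("DataFrame must have a 'vehicle' column.")
        self._df = df

    @property
    def flying(self):
        # Keep only the rows whose vehicle can fly.
        return self._df[self._df["vehicle"].isin(FLYING)]


df = pd.DataFrame({"vehicle": ["Plane", "Plane", "Car"], "seats": [180, 220, 5]})
flying_rows = df.fleet.flying
print(flying_rows)
```

The accessor receives the whole dataframe, so validation here checks for a required column rather than a dtype.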

Answered by Nathan on Oct 26 '22 at 23:10