 

How to add arguments to a class that extends `Poly` class in sympy?

I am trying to create a class called PolyExt that extends SymPy's Poly class and has its own __init__ method. The problem is that instantiation first goes through the __new__ method inherited from Poly, and there the extra arguments I added to __init__ are interpreted as part of the *gens argument. Is there an elegant way around this?

Here is how to reproduce the error:

from sympy import Poly, symbols

class PolyExt(Poly):

    def __init__(self, expression, symb1, symb2):
        self.symb1 = symb1
        self.symb2 = symb2
        super(PolyExt, self).__init__(expression)

x = symbols('x')
y = symbols('y')

PolyExt(x+y, [y], [x])
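The order of events can be reproduced without SymPy. Calling a class passes the same arguments first to __new__ and then to __init__, so a base-class __new__ that accepts *gens receives symb1 and symb2 before the subclass's __init__ ever runs. A minimal sketch, where Base and Ext are hypothetical stand-ins for Poly and PolyExt, not SymPy classes:

```python
class Base:
    # Hypothetical stand-in for Poly: like Poly.__new__, it treats every
    # positional argument after the first as a generator.
    def __new__(cls, rep, *gens):
        obj = super().__new__(cls)
        obj.rep = rep
        obj.gens = gens
        return obj


class Ext(Base):
    def __init__(self, rep, symb1, symb2):
        # By the time this runs, Base.__new__ has already received
        # symb1 and symb2 and stored them as "generators".
        self.symb1 = symb1
        self.symb2 = symb2


e = Ext("x + y", ["y"], ["x"])
print(e.gens)  # (['y'], ['x'])
```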

For reference, this is the relevant part of the Poly class in SymPy:

class Poly(Expr):
    """Generic class for representing polynomial expressions."""

    __slots__ = ['rep', 'gens']

    is_commutative = True
    is_Poly = True

    def __new__(cls, rep, *gens, **args):
        """Create a new polynomial instance out of something useful. """
        opt = options.build_options(gens, args)
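The excerpt also declares __slots__. In plain Python, when every class in a hierarchy declares __slots__, instances have no per-instance __dict__, but a subclass that does not declare __slots__ gets one back, which is what makes it possible to attach extra attributes such as symb1 to subclass instances. A stdlib-only illustration with hypothetical classes Base and Ext:

```python
class Base:
    # Like the Poly excerpt, declare __slots__ so instances get no __dict__.
    __slots__ = ("rep", "gens")


class Ext(Base):
    # No __slots__ here, so Ext instances get a __dict__ again.
    pass


b = Base()
base_rejected = False
try:
    b.symb1 = ["y"]  # not a declared slot and no __dict__ to fall back on
except AttributeError:
    base_rejected = True

e = Ext()
e.symb1 = ["y"]  # stored in the subclass instance's __dict__
print(base_rejected, e.symb1)  # True ['y']
```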

I also tried passing symb1 and symb2 as keyword arguments, but again the __new__ method of the Poly class gets in the way. Try the code below:

from sympy import Poly, symbols

class PolyExt(Poly):

    def __init__(self, expression, **kwargs):
        self.symb1 = kwargs['symb1']
        self.symb2 = kwargs['symb2']
        super(PolyExt, self).__init__(expression)

x = symbols('x')
y = symbols('y')

PolyExt(x+y, symb1=[y], symb2=[x])
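The keyword variant fails for a similar reason: symb1 and symb2 arrive in __new__'s **args and, per the excerpt above, are handed to options.build_options together with any genuine options. As far as I can tell, SymPy then raises an OptionError for the unrecognised names. The sketch below mimics that with a hypothetical whitelist and a ValueError; it is a stand-in, not SymPy's actual validation code:

```python
class Base:
    # Hypothetical stand-in for Poly: unknown keyword options are rejected
    # inside __new__, before any subclass __init__ gets a chance to run.
    _valid_options = {"domain", "order"}

    def __new__(cls, rep, *gens, **args):
        unknown = set(args) - cls._valid_options
        if unknown:
            raise ValueError(f"not valid options: {sorted(unknown)}")
        obj = super().__new__(cls)
        obj.rep, obj.gens, obj.options = rep, gens, args
        return obj


class Ext(Base):
    def __init__(self, rep, **kwargs):
        # Never reached: Base.__new__ has already rejected symb1/symb2.
        self.symb1 = kwargs["symb1"]
        self.symb2 = kwargs["symb2"]


error = None
try:
    Ext("x + y", symb1=["y"], symb2=["x"])
except ValueError as exc:
    error = str(exc)
print(error)  # not valid options: ['symb1', 'symb2']
```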

What can I do to get around this? Do I need to write my own __new__ method in PolyExt to override the one inherited from Poly? Or is there something else I'm missing?

Thank you.

asked Aug 15 '19 at 18:08 by makansij


1 Answer

You'll need to override __new__ as well, because Poly processes its constructor arguments in __new__, which runs before any __init__:

from sympy import Poly, symbols

class PolyExt(Poly):

    def __new__(cls, expression, symb1, symb2):
        # Give Poly.__new__ only the expression, so symb1 and symb2
        # are never interpreted as generators
        obj = Poly.__new__(cls, expression)
        # Attach the extra data to the newly created instance
        obj.symb1 = symb1
        obj.symb2 = symb2
        return obj

x = symbols('x')
y = symbols('y')

p = PolyExt(x+y, [y], [x])
print(p.symb1)
answered Oct 17 '22 at 13:10 by Oscar Benjamin
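A variant of this approach, sketched under my own assumptions rather than taken from the answer: making symb1 and symb2 keyword-only keeps them out of *gens, and because they are named parameters of PolyExt.__new__ they never reach Poly's option handling. That leaves generators and real options such as domain free to be forwarded to Poly.__new__. The attribute names come from the question; the None defaults are my choice.

```python
from sympy import Poly, symbols


class PolyExt(Poly):
    """Poly subclass carrying two extra attributes (names from the question)."""

    def __new__(cls, rep, *gens, symb1=None, symb2=None, **args):
        # symb1/symb2 are keyword-only, so they can never land in *gens,
        # and since they are consumed here they never reach Poly's **args.
        obj = super().__new__(cls, rep, *gens, **args)
        obj.symb1 = symb1
        obj.symb2 = symb2
        return obj


x, y = symbols('x y')

p = PolyExt(x + y, x, y, symb1=[y], symb2=[x], domain='QQ')
print(p.symb1, p.symb2)
print(p.gens, p.domain)
```

This only attaches the attributes when PolyExt(...) is called directly. I have not checked which constructor path SymPy uses for polynomials produced by arithmetic or other methods, so do not assume those results carry symb1 and symb2.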