
How do I define a conditional function using sympy?

I want to define an expression that evaluates normally wherever it is defined and evaluates to 0 where it is not defined, similar to this:

    import numpy as np
    import sympy as sp

    def expr(k1, k2):
        x, y =sp.symbols('x y')
        if x == k1 :
            fn = 0
        else:
            fn = np.divide(1,(x-k1)*(y-k2))
        return fn, x, y

    f,x, y = expr(1,2)
    print(f)
    fx = f.subs({x:1,y:4})
    print(fx)

So how is the equality or conditionality going to be checked once the function has been defined?

For example, with fn = 1/((x-1)*(y-2)), how do I make it evaluate to 0 for x=1 or y=2?

Manish asked Aug 09 '16 03:08



2 Answers

If you want a symbolic function, use Piecewise:

    expr = Piecewise((0, Eq(x, k1)), (1/(x - k1)/(y - k2), True))
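The question asks for zero when x=1 or y=2, while the one-liner above only tests x == k1. A minimal sketch that covers both cases, assuming k1 and k2 are also declared as symbols and that combining the two equalities with Or is acceptable (the names at_pole and regular are only illustrative):

```python
import sympy as sp

x, y, k1, k2 = sp.symbols('x y k1 k2')

# Zero where x == k1 or y == k2, 1/((x - k1)*(y - k2)) everywhere else.
expr = sp.Piecewise(
    (0, sp.Or(sp.Eq(x, k1), sp.Eq(y, k2))),
    (1 / ((x - k1) * (y - k2)), True),
)

# Substitute concrete values; the conditions are resolved symbolically.
at_pole = expr.subs({k1: 1, k2: 2, x: 1, y: 4})
regular = expr.subs({k1: 1, k2: 2, x: 3, y: 4})
print(at_pole)
print(regular)
```

The conditions are stored inside the expression, so they are checked at substitution time rather than when the Python function that builds the expression runs.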

If you later want to evaluate this expression on numeric values, you should convert it to a numeric function with lambdify:

    f = lambdify((x, y, k1, k2), expr, 'numpy')

I do not recommend trying to mix NumPy and SymPy functions, as that generally won't work. NumPy functions don't know how to work with SymPy expressions and SymPy functions don't know how to work with NumPy arrays. The better way is to create the symbolic expression with SymPy, manipulate it however you need, then use lambdify to convert it to a NumPy function.
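As a rough end-to-end sketch of that workflow, the following builds the expression in SymPy and only then converts it for NumPy arrays. The Or condition, the sample arrays xs and ys, and the errstate guard are assumptions added for illustration; the guard is there because the division branch is likely evaluated for every element before the condition selects a value.

```python
import numpy as np
import sympy as sp

x, y, k1, k2 = sp.symbols('x y k1 k2')

# Build and manipulate the expression purely in SymPy.
expr = sp.Piecewise(
    (0, sp.Or(sp.Eq(x, k1), sp.Eq(y, k2))),
    (1 / ((x - k1) * (y - k2)), True),
)

# Convert once to a NumPy-aware function.
f = sp.lambdify((x, y, k1, k2), expr, 'numpy')

# Hypothetical sample inputs: the first two points lie on x == 1 or y == 2.
xs = np.array([1.0, 3.0, 5.0])
ys = np.array([4.0, 2.0, 4.0])

# Silence divide-by-zero warnings from the branch that is discarded.
with np.errstate(divide='ignore', invalid='ignore'):
    values = f(xs, ys, 1, 2)
print(values)
```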

asmeurer answered Nov 09 '22 15:11



You can define a function inside your function and return it, like this:

    import numpy as np
    import sympy as sp

    def expr(k1, k2):
        x, y = sp.symbols('x y')
        def fn(x, y):
            if x == k1:
                return 0
            else:
                return np.divide(1, (x - k1) * (y - k2))
        return fn, x, y

    f, x, y = expr(1, 2)
    print(f(x, y))
    print(f(1, 4))

EDIT:

Here is one way to use sp.lambdify as asked in the comments:

    x_dot = 1 / ((x - 1) * (y - 2))
    f = lambda a, b: 0 if a == 1 or b == 2 else sp.lambdify((x, y), x_dot, "numpy")(a, b)

Another option is to use the expression's subs method:

    f = lambda a, b: 0 if a == 1 or b == 2 else float(x_dot.subs({x: a, y: b}))
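
The lambdify variant above rebuilds the numeric function on every call. A possible variation, sketched here for scalar inputs only, compiles it once and reuses it; the helper name _x_dot_numeric is hypothetical and not part of the original answer:

```python
import sympy as sp

x, y = sp.symbols('x y')
x_dot = 1 / ((x - 1) * (y - 2))

# Compile the symbolic expression once, outside the wrapper.
_x_dot_numeric = sp.lambdify((x, y), x_dot, "numpy")

def f(a, b):
    """Return 0 when a == 1 or b == 2, otherwise evaluate x_dot at (a, b)."""
    if a == 1 or b == 2:
        return 0.0
    return _x_dot_numeric(a, b)

print(f(1, 4))
print(f(3, 4))
```

Because the check is an ordinary Python if, this version works for scalars but not for NumPy arrays.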
Aguy answered Nov 09 '22 16:11
