Fastest way to find the smallest positive real root of a quartic (degree-4) polynomial in Python

[What I want] I want to find the single smallest positive real root of the quartic function ax^4 + bx^3 + cx^2 + dx + e.

[Existing Method] My equation is used for collision prediction. Its highest degree is quartic, f(x) = ax^4 + bx^3 + cx^2 + dx + e, and the coefficients a, b, c, d, e can be positive, negative, or zero (real floating-point values). So f(x) can be quartic, cubic, or quadratic depending on the input coefficients.
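Because the effective degree depends on which leading coefficients are zero, one option is to strip those coefficients before choosing a solver (numpy.roots does the same internally for exact zeros). A minimal sketch; `trim_leading_zeros` and its `rel_tol` parameter are hypothetical names, and the default only removes exact zeros:

```python
def trim_leading_zeros(coeffs, rel_tol=0.0):
    """Drop leading coefficients that are zero (or negligible relative to the
    largest coefficient), so [0, 0, c, d, e] is treated as a quadratic.

    rel_tol is a hypothetical knob: 0.0 strips exact zeros only.
    """
    scale = max((abs(x) for x in coeffs), default=0.0)
    i = 0
    # Keep at least one coefficient so the result is never empty
    while i < len(coeffs) - 1 and abs(coeffs[i]) <= rel_tol * scale:
        i += 1
    return list(coeffs[i:])
```

For example, `len(trim_leading_zeros([a, b, c, d, e])) - 1` gives 4, 3, or 2 for the quartic, cubic, and quadratic cases. A nonzero `rel_tol` is an application-specific choice: treating a tiny but nonzero leading coefficient as zero changes which roots you get.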

Currently, I use NumPy to find roots as below.

import numpy

root_output = numpy.roots([a, b, c, d, e])

The root_output returned by NumPy contains all roots, real or complex, depending on the input coefficients. So I have to check root_output one by one to find the smallest real positive value (root > 0).
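The filtering step described above can be written as one pass over the roots. A sketch, assuming a root counts as real when its imaginary part is below a small tolerance, since floating-point root finders can leave a tiny imaginary part on a real root; `smallest_positive_real` and `imag_tol` are hypothetical names:

```python
def smallest_positive_real(roots, imag_tol=1e-9):
    """Return the smallest root that is (numerically) real and > 0, or None.

    roots may be any iterable of real or complex numbers, e.g. the array
    returned by numpy.roots. imag_tol is an assumed relative threshold for
    treating a root with a tiny imaginary part as real.
    """
    best = None
    for z in roots:
        z = complex(z)
        # Real enough and strictly positive?
        if abs(z.imag) <= imag_tol * max(1.0, abs(z.real)) and z.real > 0:
            if best is None or z.real < best:
                best = z.real
    return best
```

Returning `None` instead of raising keeps the "no collision" case cheap to test for in a tight loop.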

[The Problem] My program needs to execute numpy.roots([a, b, c, d, e]) many times, and the values of (a, b, c, d, e) change on every call. Executing numpy.roots that many times is too slow for my project.

I am trying to run the code on a Raspberry Pi 2. Below is an example of the processing time.

  • Running numpy.roots many times on a PC: 1.2 seconds
  • Running numpy.roots many times on a Raspberry Pi 2: 17 seconds

Could you please guide me to the fastest way to find the smallest positive real root? Using scipy.optimize, implementing some algorithm to speed up root finding, or any other advice would be great.

Thank you.

[Solution]

  • Quadratic function: only the real positive roots are needed (beware of division by zero when a == 0).
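import math

def SolvQuadratic(a, b, c):
    # Real positive roots of a*x^2 + b*x + c = 0, sorted ascending
    if a == 0:
        # Degenerate (linear) case: avoids the division by 2*a below
        if b == 0:
            return []
        t = -c / b
        return [t] if t > 0 else []

    d = (b**2) - (4*a*c)  # discriminant
    if d < 0:
        return []  # only complex roots

    if d > 0:
        square_root_d = math.sqrt(d)
        t1 = (-b + square_root_d) / (2 * a)
        t2 = (-b - square_root_d) / (2 * a)
        if t1 > 0:
            if t2 > 0:
                if t1 < t2:
                    return [t1, t2]
                return [t2, t1]
            return [t1]
        elif t2 > 0:
            return [t2]
        else:
            return []
    else:
        # d == 0: a single (double) root
        t = -b / (2*a)
        if t > 0:
            return [t]
        return []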
  • Quartic function: you can use the pure Python/Numba version from @B.M.'s answer below. I also added a Cython version based on @B.M.'s code. Save the code below as a .pyx file and compile it to get roughly a 2x speedup over pure Python (beware of rounding issues).
import cmath

cdef extern from "complex.h":
    double complex cexp(double complex)

cdef double complex  J=cexp(2j*cmath.pi/3)
cdef double complex  Jc=1/J

cdef Cardano(double a, double b, double c, double d):
    cdef double z0
    cdef double a2, b2
    cdef double p ,q, D
    cdef double complex r
    cdef double complex u, v, w
    cdef double w0, w1, w2
    cdef double complex r1, r2, r3


    z0=b/3/a
    a2,b2 = a*a,b*b
    p=-b2/3/a2 +c/a
    q=(b/27*(2*b2/a2-9*c/a)+d)/a
    D=-4*p*p*p-27*q*q
    r=cmath.sqrt(-D/27+0j)
    u=((-q-r)/2)**0.33333333333333333333333
    v=((-q+r)/2)**0.33333333333333333333333
    w=u*v
    w0=abs(w+p/3)
    w1=abs(w*J+p/3)
    w2=abs(w*Jc+p/3)
    if w0<w1:
      if w2<w0 : v = v*Jc
    elif w2<w1 : v = v*Jc
    else: v = v*J
    r1 = u+v-z0
    r2 = u*J+v*Jc-z0
    r3 = u*Jc+v*J-z0
    return r1, r2, r3

cdef Roots_2(double a, double complex b, double complex c):
    cdef double complex bp
    cdef double complex delta
    cdef double complex r1, r2


    bp=b/2
    delta=bp*bp-a*c
    r1=(-bp-delta**.5)/a
    r2=-r1-b/a
    return r1, r2

def SolveQuartic(double a, double b, double c, double d, double e):
    "Ferrari's Method"
    "resolution of P=ax^4+bx^3+cx^2+dx+e=0, coeffs reals"
    "First shift : x= z-b/4/a  =>  P=z^4+pz^2+qz+r"
    cdef double z0
    cdef double a2, b2, c2, d2
    cdef double p, q, r
    cdef double A, B, C, D
    cdef double complex y0, y1, y2
    cdef double complex a0, b0
    cdef double complex r0, r1, r2, r3


    z0=b/4.0/a
    a2,b2,c2,d2 = a*a,b*b,c*c,d*d
    p = -3.0*b2/(8*a2)+c/a
    q = b*b2/8.0/a/a2 - 1.0/2*b*c/a2 + d/a
    r = -3.0/256*b2*b2/a2/a2 + c*b2/a2/a/16 - b*d/a2/4+e/a
    "Second find y so P2=Ay^3+By^2+Cy+D=0"
    A=8.0
    B=-4*p
    C=-8*r
    D=4*r*p-q*q
    y0,y1,y2=Cardano(A,B,C,D)
    if abs(y1.imag)<abs(y0.imag): y0=y1
    if abs(y2.imag)<abs(y0.imag): y0=y2
    a0=(-p+2*y0)**.5
    if a0==0 : b0=(y0**2-r)**.5  # when a0 == 0 we need b0**2 == y0**2 - r
    else : b0=-q/2/a0
    r0,r1=Roots_2(1,a0,y0+b0)
    r2,r3=Roots_2(1,-a0,y0-b0)
    return (r0-z0,r1-z0,r2-z0,r3-z0)
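On the rounding issues: the textbook formula (-b ± sqrt(d)) / (2a) used in SolvQuadratic loses precision when b*b is much larger than 4*a*c, because -b and sqrt(d) nearly cancel. A common remedy (not part of the original answer) computes the larger-magnitude root first and gets the other from the product of the roots, c/a. A sketch; `stable_quadratic_positive_roots` is a hypothetical name:

```python
import math


def stable_quadratic_positive_roots(a, b, c):
    """Real positive roots of a*x**2 + b*x + c = 0, sorted ascending.

    Uses q = -(b + sign(b)*sqrt(disc))/2 so that -b and sqrt(disc) are never
    subtracted from each other (the source of cancellation error).
    """
    if a == 0:
        # Linear fallback: b*x + c = 0
        if b == 0:
            return []
        roots = [-c / b]
    else:
        disc = b * b - 4.0 * a * c
        if disc < 0:
            return []  # only complex roots
        q = -0.5 * (b + math.copysign(math.sqrt(disc), b))
        if q == 0:
            # Only possible when b == 0 and c == 0: double root at 0
            roots = [0.0]
        else:
            # Roots are q/a and c/q (their product is c/a)
            roots = [q / a, c / q]
    return sorted(r for r in set(roots) if r > 0)
```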

[Problem of Ferrari's method] We ran into a problem with the quartic coefficients [0.00614656, -0.0933333333333, 0.527664995846, -1.31617928376, 1.21906444869]: the outputs of numpy.roots and the Ferrari method are entirely different (numpy.roots gives the correct output).

import numpy as np
import cmath


J=cmath.exp(2j*cmath.pi/3)
Jc=1/J

def ferrari(a,b,c,d,e):
    "Ferrari's Method"
    "resolution of P=ax^4+bx^3+cx^2+dx+e=0, coeffs reals"
    "First shift : x= z-b/4/a  =>  P=z^4+pz^2+qz+r"
    # Float literals below: under Python 2, 1/2 == 0 and -3/256 == -1
    z0=b/4.0/a
    a2,b2,c2,d2 = a*a,b*b,c*c,d*d
    p = -3.0*b2/(8*a2)+c/a
    q = b*b2/8.0/a/a2 - 0.5*b*c/a2 + d/a
    r = -3.0/256*b2*b2/a2/a2 +c*b2/a2/a/16-b*d/a2/4+e/a
    "Second find y so P2=Ay^3+By^2+Cy+D=0"
    A=8
    B=-4*p
    C=-8*r
    D=4*r*p-q*q
    y0,y1,y2=Cardano(A,B,C,D)
    if abs(y1.imag)<abs(y0.imag): y0=y1
    if abs(y2.imag)<abs(y0.imag): y0=y2
    a0=(-p+2*y0)**.5
    if a0==0 : b0=(y0**2-r)**.5  # when a0 == 0 we need b0**2 == y0**2 - r
    else : b0=-q/2/a0
    r0,r1=Roots_2(1,a0,y0+b0)
    r2,r3=Roots_2(1,-a0,y0-b0)
    return (r0-z0,r1-z0,r2-z0,r3-z0)

#~ @jit(nopython=True)
def Cardano(a,b,c,d):
    z0=b/3/a
    a2,b2 = a*a,b*b
    p=-b2/3/a2 +c/a
    q=(b/27*(2*b2/a2-9*c/a)+d)/a
    D=-4*p*p*p-27*q*q
    r=cmath.sqrt(-D/27+0j)
    u=((-q-r)/2)**0.33333333333333333333333
    v=((-q+r)/2)**0.33333333333333333333333
    w=u*v
    w0=abs(w+p/3)
    w1=abs(w*J+p/3)
    w2=abs(w*Jc+p/3)
    if w0<w1:
      if w2<w0 : v*=Jc
    elif w2<w1 : v*=Jc
    else: v*=J
    return u+v-z0, u*J+v*Jc-z0, u*Jc+v*J-z0

#~ @jit(nopython=True)
def Roots_2(a,b,c):
    bp=b/2
    delta=bp*bp-a*c
    r1=(-bp-delta**.5)/a
    r2=-r1-b/a
    return r1,r2

coef = [0.00614656, -0.0933333333333, 0.527664995846, -1.31617928376, 1.21906444869]
print("Coefficient A, B, C, D, E", coef) 
print("") 
print("numpy roots: ", np.roots(coef)) 
print("") 
print("ferrari python ", ferrari(*coef))
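To tell which solver is right without trusting either one, you can substitute each returned root back into the polynomial. The sketch below (hypothetical helper names `horner` and `relative_residual`) evaluates P(z) with Horner's scheme and scales it by the size of the individual terms, so values are comparable across coefficient sets:

```python
def horner(coeffs, z):
    """Evaluate the polynomial with coefficients [a, b, c, ...] (highest
    power first) at z, which may be real or complex."""
    acc = 0
    for k in coeffs:
        acc = acc * z + k
    return acc


def relative_residual(coeffs, z):
    """|P(z)| divided by sum(|k| * |z|**power): a scale-free measure of how
    well z satisfies P(z) = 0 (always between 0 and 1)."""
    n = len(coeffs) - 1
    scale = sum(abs(k) * abs(z) ** (n - i) for i, k in enumerate(coeffs))
    return abs(horner(coeffs, z)) / scale if scale else 0.0
```

For instance, compare `[relative_residual(coef, z) for z in ferrari(*coef)]` with the same list for `np.roots(coef)`. A root from a backward-stable computation gives a value close to machine epsilon (about 1e-16), while a wrong root gives a much larger one.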
asked Mar 04 '16 12:03 by Utthawut



1 Answer

Another answer:

Do it with analytic methods (Ferrari, Cardano) and speed up the code with just-in-time compilation (Numba):

First, let's look at the improvement:

In [2]: P=poly1d([1,2,3,4],True)

In [3]: roots(P)
Out[3]: array([ 4.,  3.,  2.,  1.])

In [4]: %timeit roots(P)
1000 loops, best of 3: 465 µs per loop

In [5]: ferrari(*P.coeffs)
Out[5]: ((1+0j), (2-0j), (3+0j), (4-0j))

In [5]: %timeit ferrari(*P.coeffs) #pure python without jit
10000 loops, best of 3: 116 µs per loop    
In [6]: %timeit ferrari(*P.coeffs)  # with numba.jit
100000 loops, best of 3: 13 µs per loop
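The benchmark builds its test polynomial from known roots with poly1d(..., True). For testing the solvers without NumPy (for example directly on the Raspberry Pi), a pure-Python equivalent could look like this; `poly_from_roots` is a hypothetical helper:

```python
def poly_from_roots(roots):
    """Coefficients (highest power first) of prod(x - r) for r in roots,
    mirroring what numpy.poly1d(roots, True) builds."""
    coeffs = [1]
    for r in roots:
        # Multiply the current polynomial by (x - r):
        # new[i] = old[i] - r * old[i - 1]
        coeffs = [hi - r * lo for hi, lo in zip(coeffs + [0], [0] + coeffs)]
    return coeffs
```

Feeding `poly_from_roots([r1, r2, r3, r4])` to `ferrari` and comparing against the known roots gives a quick correctness check. It is worth including nearly equal roots, where root finding in general is ill-conditioned.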

Then the (ugly) code:

For order 4:

from numba import jit

@jit(nopython=True)
def ferrari(a,b,c,d,e):
    "resolution of P=ax^4+bx^3+cx^2+dx+e=0"
    "Requires all coefficients to be real."
    "First shift : x= z-b/4/a  =>  P=z^4+pz^2+qz+r"
    # Float literals below: under Python 2, 1/2 == 0 and -3/256 == -1
    z0=b/4.0/a
    a2,b2,c2,d2 = a*a,b*b,c*c,d*d 
    p = -3.0*b2/(8*a2)+c/a
    q = b*b2/8.0/a/a2 - 0.5*b*c/a2 + d/a
    r = -3.0/256*b2*b2/a2/a2 +c*b2/a2/a/16-b*d/a2/4+e/a
    "Second find y so P2=Ay^3+By^2+Cy+D=0"
    A=8
    B=-4*p
    C=-8*r
    D=4*r*p-q*q
    y0,y1,y2=cardan(A,B,C,D)
    if abs(y1.imag)<abs(y0.imag): y0=y1 
    if abs(y2.imag)<abs(y0.imag): y0=y2 
    a0=(-p+2*y0.real)**.5
    if a0==0 : b0=(y0**2-r)**.5  # when a0 == 0 we need b0**2 == y0**2 - r
    else : b0=-q/2/a0
    r0,r1=roots2(1,a0,y0+b0)
    r2,r3=roots2(1,-a0,y0-b0)
    return (r0-z0,r1-z0,r2-z0,r3-z0) 

For order 3:

from cmath import exp, pi, sqrt  # assumed imports; the answer does not show them

J=exp(2j*pi/3)
Jc=1/J

@jit(nopython=True) 
def cardan(a,b,c,d):
    z0=b/3/a
    a2,b2 = a*a,b*b    
    p=-b2/3/a2 +c/a
    q=(b/27*(2*b2/a2-9*c/a)+d)/a
    D=-4*p*p*p-27*q*q
    r=sqrt(-D/27+0j)        
    u=((-q-r)/2)**0.33333333333333333333333
    v=((-q+r)/2)**0.33333333333333333333333
    w=u*v
    w0=abs(w+p/3)
    w1=abs(w*J+p/3)
    w2=abs(w*Jc+p/3)
    if w0<w1: 
        if w2<w0 : v*=Jc
    elif w2<w1 : v*=Jc
    else: v*=J        
    return u+v-z0, u*J+v*Jc-z0,u*Jc+v*J-z0

For order 2:

@jit(nopython=True)
def roots2(a,b,c):
    bp=b/2    
    delta=bp*bp-a*c
    u1=(-bp-delta**.5)/a
    u2=-u1-b/a
    return u1,u2  
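Since only real roots are needed, the cubic step can also avoid complex arithmetic altogether. The sketch below is a swapped-in alternative to `cardan`, not B. M.'s code: it uses Cardano's formula with real cube roots when the cubic has one real root, and the trigonometric (Viète) form when it has three. `real_cubic_roots` and `_cbrt` are hypothetical names:

```python
import math


def _cbrt(x):
    """Real cube root that also works for negative x."""
    return math.copysign(abs(x) ** (1.0 / 3.0), x)


def real_cubic_roots(a, b, c, d):
    """Real roots of a*x**3 + b*x**2 + c*x + d = 0 (a != 0), sorted ascending."""
    shift = b / (3.0 * a)
    # Depressed cubic t**3 + p*t + q = 0 with x = t - shift (same p, q as cardan)
    p = (3.0 * a * c - b * b) / (3.0 * a * a)
    q = (2.0 * b ** 3 - 9.0 * a * b * c + 27.0 * a * a * d) / (27.0 * a ** 3)
    disc = (q / 2.0) ** 2 + (p / 3.0) ** 3
    if disc > 0:
        # One real root: Cardano's formula with real cube roots only
        s = math.sqrt(disc)
        ts = [_cbrt(-q / 2.0 + s) + _cbrt(-q / 2.0 - s)]
    elif p == 0:
        # disc <= 0 and p == 0 force q == 0: triple root t = 0
        ts = [0.0]
    else:
        # Three real roots (here p < 0): trigonometric form, no complex numbers
        m = 2.0 * math.sqrt(-p / 3.0)
        arg = max(-1.0, min(1.0, 3.0 * q / (p * m)))
        theta = math.acos(arg) / 3.0
        ts = [m * math.cos(theta - 2.0 * math.pi * k / 3.0) for k in range(3)]
    return sorted(t - shift for t in ts)
```

The clamp around the `math.acos` argument keeps it from raising `ValueError` when rounding pushes the value slightly outside [-1, 1].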

It probably needs more testing, but it is efficient.
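Because the closed-form roots can carry rounding error (the .pyx note in the question mentions this), a common fix is to polish each real candidate with a few Newton steps on the original polynomial. A sketch of that idea, not part of the original answer; `newton_polish`, `max_iter`, and `tol` are hypothetical:

```python
def newton_polish(coeffs, x, max_iter=8, tol=1e-15):
    """Refine an approximate real root x of the polynomial coeffs (highest
    power first) with Newton's method. Returns the refined value; it does
    not check that x was close to a root to begin with."""
    for _ in range(max_iter):
        # Horner's scheme evaluating P(x) and P'(x) together
        p, dp = 0.0, 0.0
        for k in coeffs:
            dp = dp * x + p
            p = p * x + k
        if dp == 0:
            break  # flat spot (e.g. a multiple root): stop instead of dividing by zero
        step = p / dp
        x -= step
        if abs(step) <= tol * max(1.0, abs(x)):
            break
    return x
```

Convergence is quadratic near a simple root but only linear near a double root (in collision prediction, typically a grazing contact), so `max_iter` bounds the cost in that case.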

answered Sep 28 '22 14:09 by B. M.