 

Can I speed up this basic linear algebra code?

I was wondering whether it is possible to optimise the following using Numpy or mathematical trickery.

import numpy as np

def f1(g, b, dt, t1, t2):
    p = np.copy(g)  # copy so that g is not modified in place
    for i in range(dt):
        # p <- p + t1*tanh(p.b) + t2*p
        p += t1*np.tanh(np.dot(p, b)) + t2*p
    return p

where g is a vector of length n, b is an n x n matrix, dt is the number of iterations, and t1 and t2 are scalars.

I quickly ran out of ideas on how to optimise this further, because p is used in all three terms of the update inside the loop: it is added to itself, it appears in the dot product, and it is multiplied by a scalar.
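
One small piece of algebra applies to the scalar terms: p + t2*p equals (1 + t2)*p, so two of the three uses of p can be folded into a single scaling. The sketch below (`f1_fused` is a hypothetical name) computes the same recurrence and uses NumPy's `out=` arguments to reuse one preallocated buffer instead of creating temporaries on every step. It is not guaranteed to be faster, and its results can differ from `f1` in the last bits because the floating-point operations are grouped differently.

```python
import numpy as np


def f1_fused(g, b, dt, t1, t2):
    """Same recurrence as f1, p <- (1 + t2)*p + t1*tanh(p @ b), applied dt times."""
    p = np.array(g, dtype=np.float64)    # float copy, so g itself is never modified
    b = np.asarray(b, dtype=np.float64)  # np.dot(..., out=z) needs a float64 result
    z = np.empty_like(p)                 # buffer reused for p @ b on every step
    scale = 1.0 + t2                     # p + t2*p == (1 + t2)*p
    for _ in range(dt):
        np.dot(p, b, out=z)              # z = p @ b, written into the buffer
        np.tanh(z, out=z)                # z = tanh(p @ b), in place
        z *= t1                          # z = t1*tanh(p @ b)
        p *= scale                       # p = (1 + t2)*p
        p += z                           # p = (1 + t2)*p + t1*tanh(p_old @ b)
    return p
```

For the MWE inputs, `np.allclose(f1_fused(g, b, dt, t1, t2), f1(g, b, dt, t1, t2))` should hold; whether the fused version is measurably faster depends on the NumPy/BLAS build, so it is worth timing both.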

But maybe there is a different way to represent this function, or there are other tricks to improve its efficiency. If possible, I would prefer not to use Cython etc., but I'd be willing to use it if the speed improvements are significant. Thanks in advance, and apologies if the question is somehow out of scope.

Update:

The answers provided so far focus on what the values of the input/output could be, in order to avoid unnecessary operations. I have now updated the MWE with proper initialisation values for the variables (I didn't expect the optimisation ideas to come from that side; apologies). g will be in the range [-1, 1] and b will be in the range [-infinity, infinity]. Approximating the output is not an option: the returned vectors are later given to an evaluation function, and an approximation may return the same vector for fairly similar inputs.


MWE:

import numpy as np
import timeit

iterations = 10000

setup = """
import numpy as np
n  = 100
g  = np.random.uniform(-1, 1, (n,)) # Updated.
b  = np.random.uniform(-1, 1, (n,n)) # Updated.
dt = 10
t1 = 1
t2 = 0.5  # written as 0.5 because 1/2 is 0 under Python 2 integer division

def f1(g, b, dt, t1, t2):
  p = np.copy(g)
  for i in range(dt):
    p += t1*np.tanh(np.dot(p, b)) + t2*p
  return p
"""

# Statements to time; keep them unindented so that timeit can compile them
functions = [
  "p = f1(g, b, dt, t1, t2)",
]

if __name__ == '__main__':
  for function in functions:
    print(function)
    print('Time = {}'.format(timeit.timeit(function, setup=setup,
                                           number=iterations)))
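
As an alternative harness, `timeit.timeit` also accepts a zero-argument callable, so the function and its inputs can stay ordinary Python objects rather than source strings. A minimal sketch, assuming NumPy 1.17 or later for `np.random.default_rng`; the `benchmark` helper and the fixed seed are illustrative additions, not part of the original MWE:

```python
import timeit

import numpy as np


def f1(g, b, dt, t1, t2):
    p = np.copy(g)
    for i in range(dt):
        p += t1 * np.tanh(np.dot(p, b)) + t2 * p
    return p


def benchmark(func, n=100, dt=10, t1=1, t2=0.5, number=10000, seed=0):
    """Return the total seconds for `number` calls of func on random MWE-style inputs."""
    rng = np.random.default_rng(seed)  # fixed seed so repeated runs time the same inputs
    g = rng.uniform(-1, 1, n)
    b = rng.uniform(-1, 1, (n, n))
    # timeit accepts a zero-argument callable instead of a statement string
    return timeit.timeit(lambda: func(g, b, dt, t1, t2), number=number)


if __name__ == '__main__':
    print('Time = {}'.format(benchmark(f1)))
```

Any candidate rewrite with the same signature can be passed in place of `f1`, and it will be timed on identical inputs.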
asked Mar 10 '14 at 16:03 by sudosensei




1 Answer

Getting the code to run much faster without Cython or a JIT will be very hard, so some mathematical trickery may be the easier approach. Define k(g, b) = f1(g, b, m+1, t1, t2)/f1(g, b, m, t1, t2) (element-wise) for a positive integer m. It appears to me that k converges to a constant as m grows. I don't have a solid proof yet (and it may be a special case for E(g) = 0 and E(p) = 0), but since tanh is bounded by 1 in magnitude, once the entries of p are large each update is dominated by p + t2*p = (1 + t2)*p, which suggests a limit of 1 + t2. For t1 = 1 and t2 = 0.5 this coincides with t1 + t2 = 1.5, and k approaches the limit fairly quickly: for m > 100 it is almost exactly 1.5.
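
To see which constant k approaches, one can pick t1 ≠ 1 so that the two candidates 1 + t2 and t1 + t2 differ. The sketch below is an illustration under assumptions: t1 = 2, t2 = 0.5, the seed and the MWE-style random inputs are arbitrary choices, and `ratio_after` is a hypothetical helper that gets k from one run of m steps plus one extra step, instead of two separate calls to f1.

```python
import numpy as np


def f1(g, b, dt, t1, t2):
    p = np.copy(g)
    for i in range(dt):
        p += t1 * np.tanh(np.dot(p, b)) + t2 * p
    return p


def ratio_after(g, b, m, t1, t2):
    """Element-wise k = f1(g, b, m+1, ...) / f1(g, b, m, ...), from one run of m steps."""
    p_m = f1(g, b, m, t1, t2)
    p_next = p_m + (t1 * np.tanh(np.dot(p_m, b)) + t2 * p_m)  # one more step, same grouping as f1
    return p_next / p_m


rng = np.random.default_rng(0)
n = 100
g = rng.uniform(-1, 1, n)
b = rng.uniform(-1, 1, (n, n))
t1, t2 = 2.0, 0.5

k = ratio_after(g, b, 100, t1, t2)
print('median k =', np.median(k))
print('1 + t2 =', 1 + t2, '  t1 + t2 =', t1 + t2)
```

A median close to 1.5 rather than 2.5 would support 1 + t2 as the limit for these inputs.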

So I think a numerical approximation approach should be the easiest one.

[Figure (limit.png): element-wise ratio k after 1, 5, 10, 50, 100 and 1000 iterations]
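
A minimal sketch of that approximation: since tanh is bounded by 1 in magnitude, once the entries of p are large each step is essentially a multiplication by (1 + t2), so after a warm-up of exact steps the remaining steps can be replaced by a single power. `f1_approx` and its `warmup=100` default are illustrative assumptions, not tuned values. It only saves work when dt is much larger than `warmup` (it changes nothing for the MWE's dt = 10), and, as the question's update points out, an approximation may not be acceptable for every downstream use.

```python
import numpy as np


def f1(g, b, dt, t1, t2):
    p = np.copy(g)
    for i in range(dt):
        p += t1 * np.tanh(np.dot(p, b)) + t2 * p
    return p


def f1_approx(g, b, dt, t1, t2, warmup=100):
    """Approximate f1 for large dt: exact steps first, then geometric growth by (1 + t2).

    Assumes the iterates are large enough after `warmup` steps that the bounded
    tanh term no longer matters relative to (1 + t2)*p.
    """
    if dt <= warmup:
        return f1(g, b, dt, t1, t2)        # nothing to skip; compute exactly
    p = f1(g, b, warmup, t1, t2)
    return p * (1 + t2) ** (dt - warmup)   # extrapolate the remaining steps


rng = np.random.default_rng(0)
n = 100
g = rng.uniform(-1, 1, n)
b = rng.uniform(-1, 1, (n, n))

exact = f1(g, b, 300, 1, 0.5)
approx = f1_approx(g, b, 300, 1, 0.5)
print('max relative error:', np.max(np.abs(approx - exact) / np.abs(exact)))
```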

In [81]:

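t1 = 1
t2 = 0.5
# data[i] is the element-wise ratio f1(..., i+2)/f1(..., i+1)
data = [f1(g, b, i+2, t1, t2)/f1(g, b, i+1, t1, t2) for i in range(1000)]

In [82]:

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 5))
# one curve per iteration count; the x-axis is the element index of p
plt.plot(data[0], '.-', label='1')
plt.plot(data[4], '.-', label='5')
plt.plot(data[9], '.-', label='10')
plt.plot(data[49], '.-', label='50')
plt.plot(data[99], '.-', label='100')
plt.plot(data[999], '.-', label='1000')
plt.xlim(right=120)  # the vectors have n = 100 elements
plt.legend()
plt.savefig('limit.png')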

In [83]:

data[999]
Out[83]:
array([ 1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,  1.5,
        1.5])
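
The list comprehension in In [81] reruns f1 from scratch for every i, which adds up to about a million matrix-vector products. Since f1(..., i+2) is just one more step after f1(..., i+1), the same ratios can be collected from a single run of about a thousand steps. A sketch; `consecutive_ratios` is a hypothetical helper name:

```python
import numpy as np


def f1(g, b, dt, t1, t2):
    p = np.copy(g)
    for i in range(dt):
        p += t1 * np.tanh(np.dot(p, b)) + t2 * p
    return p


def consecutive_ratios(g, b, steps, t1, t2):
    """Return a list whose entry i equals f1(g, b, i+2, ...) / f1(g, b, i+1, ...).

    All iterates come from one pass of the recurrence, so this needs steps + 1
    matrix-vector products instead of re-running f1 for every i.
    """
    p = np.copy(g)
    iterates = []
    for _ in range(steps + 1):
        p += t1 * np.tanh(np.dot(p, b)) + t2 * p  # same update as f1
        iterates.append(p.copy())                 # iterates[j] == f1(g, b, j+1, ...)
    return [iterates[i + 1] / iterates[i] for i in range(steps)]
```

With the question's inputs, `data = consecutive_ratios(g, b, 1000, t1, t2)` should give the same list as In [81].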
answered Oct 03 '22 at 18:10 by CT Zhu