
Matrix Multiplication in Clojure vs Numpy

I'm working on an application in Clojure that needs to multiply large matrices, and I'm running into serious performance problems compared to an identical NumPy version. NumPy can multiply a 1,000,000x23 matrix by its transpose in under a second, while the equivalent Clojure code takes over six minutes. (I can print the resulting matrix from NumPy, so it is definitely evaluating everything.)

Am I doing something terribly wrong in this Clojure code? Is there some trick of Numpy that I can try to mimic?

Here's the Python (note that `time` must be imported for the timing code to run):

import time
import numpy as np

def test_my_mult(n):
    A = np.random.rand(n*23).reshape(n,23)
    At = A.T

    t0 = time.time()
    res = np.dot(A.T, A)
    print time.time() - t0
    print np.shape(res)

    return res

# Example (returns a 23x23 matrix):
# >>> results = test_my_mult(1000000)
#
# 0.906938076019
# (23, 23)
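A quick sanity check on the benchmark above: `A.T @ A` is the 23x23 Gram matrix of `A`'s columns, so the result is small and symmetric no matter how large `n` gets. A minimal sketch:

```python
import numpy as np

# The Gram matrix A.T @ A has shape (23, 23) regardless of n,
# and Gram matrices are always symmetric.
A = np.random.rand(1000, 23)
G = A.T @ A
assert G.shape == (23, 23)
assert np.allclose(G, G.T)
```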

And the Clojure:

(defn feature-vec [n]
  (map (partial cons 1)
       (for [x (range n)]
         (take 22 (repeatedly rand)))))

(defn dot-product [x y]
  (reduce + (map * x y)))

(defn transpose
  "returns the transposition of a `coll` of vectors"
  [coll]
  (apply map vector coll))

(defn matrix-mult
  [mat1 mat2]
  (let [row-mult (fn [mat row]
                   (map (partial dot-product row)
                        (transpose mat)))]
    (map (partial row-mult mat2)
         mat1)))

(defn test-my-mult
  [n afn]
  (let [xs  (feature-vec n)
        xst (transpose xs)]
    (time (dorun (afn xst xs)))))

;; Example (yields a 23x23 matrix):
;; (test-my-mult 1000 i/mmult) => "Elapsed time: 32.626 msecs"
;; (test-my-mult 10000 i/mmult) => "Elapsed time: 628.841 msecs"

;; (test-my-mult 1000 matrix-mult) => "Elapsed time: 14.748 msecs"
;; (test-my-mult 10000 matrix-mult) => "Elapsed time: 434.128 msecs"
;; (test-my-mult 1000000 matrix-mult) => "Elapsed time: 375751.999 msecs"

;; Test from wikipedia
;; (def A [[14 9 3] [2 11 15] [0 12 17] [5 2 3]])
;; (def B [[12 25] [9 10] [8 5]])

;; user> (matrix-mult A B)
;; ((273 455) (243 235) (244 205) (102 160))
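To get a feel for where the six minutes go, here is a rough Python analogue of the sequence-based Clojure code above: a per-element `dot-product` built from `zip`/`sum` (every number a boxed heap object, like a Clojure lazy seq) versus a single BLAS call. The helper names are mine, not from the question; timings will vary by machine.

```python
import time
import numpy as np

def dot_product(x, y):
    # Element-by-element dot product, analogous to the Clojure
    # (reduce + (map * x y)): each element is a boxed Python object.
    return sum(a * b for a, b in zip(x, y))

def matrix_mult(m1, m2):
    m2t = list(zip(*m2))  # transpose, like (apply map vector coll)
    return [[dot_product(row, col) for col in m2t] for row in m1]

n = 5000
A = np.random.rand(n, 23)
rows = A.tolist()      # n x 23 as nested lists
cols = A.T.tolist()    # 23 x n as nested lists

t0 = time.perf_counter()
slow = matrix_mult(cols, rows)       # 23 x n times n x 23 -> 23 x 23
t_slow = time.perf_counter() - t0

t0 = time.perf_counter()
fast = A.T @ A                       # one fused BLAS call
t_fast = time.perf_counter() - t0

print(f"interpreted: {t_slow:.4f} s, BLAS: {t_fast:.6f} s")
```

The two results agree numerically; the entire difference is boxing and interpretation overhead, not the arithmetic itself.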

UPDATE: I implemented the same benchmark using the JBLAS library and found massive, massive speed improvements. Thanks to everyone for their input! Time to wrap this sucker in Clojure. Here's the new code:

(import '[org.jblas FloatMatrix])

(defn feature-vec [n]
  (FloatMatrix.
   (into-array (for [x (range n)]
                 (float-array (cons 1 (take 22 (repeatedly rand))))))))

(defn test-mult [n]
  (let [xs  (feature-vec n)
        xst (.transpose xs)]
    (time (let [result (.mmul xst xs)]
            [(.rows result)
             (.columns result)]))))

;; user> (test-mult 10000)
;; "Elapsed time: 6.99 msecs"
;; [23 23]

;; user> (test-mult 100000)
;; "Elapsed time: 43.88 msecs"
;; [23 23]

;; user> (test-mult 1000000)
;; "Elapsed time: 383.439 msecs"
;; [23 23]

(defn matrix-stream [rows cols]
  (repeatedly #(FloatMatrix/randn rows cols)))

(defn square-benchmark
  "Times the multiplication of a square matrix."
  [n]
  (let [[a b c] (matrix-stream n n)]
    (time (.mmuli a b c))
    nil))

;; forma.matrix.jblas> (square-benchmark 10)
;; "Elapsed time: 0.113 msecs"
;; nil
;; forma.matrix.jblas> (square-benchmark 100)
;; "Elapsed time: 0.548 msecs"
;; nil
;; forma.matrix.jblas> (square-benchmark 1000)
;; "Elapsed time: 107.555 msecs"
;; nil
;; forma.matrix.jblas> (square-benchmark 2000)
;; "Elapsed time: 793.022 msecs"
;; nil
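One caveat when comparing these jblas numbers with the NumPy timing above: jblas's `FloatMatrix` is single precision, while NumPy defaults to `float64`. A small sketch to time both dtypes on the NumPy side (numbers will vary by machine and BLAS build):

```python
import time
import numpy as np

n = 200_000
A64 = np.random.rand(n, 23)      # float64: NumPy's default dtype
A32 = A64.astype(np.float32)     # float32: what jblas's FloatMatrix uses

for A in (A64, A32):
    t0 = time.perf_counter()
    G = A.T @ A                  # dtype of the result matches A
    print(f"{A.dtype}: {time.perf_counter() - t0:.4f} s, shape {G.shape}")
```

Matching dtypes makes the Clojure/jblas and NumPy runs a closer apples-to-apples comparison.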
Asked Jan 17 '12 by Sam Ritchie


People also ask

Is NumPy matrix multiplication faster?

Matrix multiplications in NumPy are reasonably fast without the need for optimization. However, if every second counts, it is possible to significantly improve performance (even without a GPU).

What matrix multiplication algorithm does NumPy use?

NumPy uses a highly optimized, carefully tuned BLAS routine for matrix multiplication (see also: ATLAS). The specific function in this case is GEMM (GEneral Matrix Multiply).
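GEMM's contract is C ← αAB + βC. The loop below is only a toy sketch of that contract in plain NumPy (a real BLAS GEMM is blocked, vectorized, and often multithreaded, nothing like this); the function name is mine:

```python
import numpy as np

def naive_gemm(alpha, A, B, beta, C):
    # Spells out GEMM's contract: C <- alpha * A @ B + beta * C.
    # BLAS computes exactly this, but in tuned native code.
    out = beta * C
    for i in range(A.shape[0]):
        for j in range(B.shape[1]):
            out[i, j] += alpha * np.dot(A[i, :], B[:, j])
    return out

A = np.random.rand(23, 50)
B = np.random.rand(50, 23)
C = np.random.rand(23, 23)
res = naive_gemm(2.0, A, B, 0.5, C)
assert np.allclose(res, 2.0 * A @ B + 0.5 * C)
```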

What is difference between NumPy array and matrix?

Numpy matrices are strictly 2-dimensional, while numpy arrays (ndarrays) are N-dimensional. Matrix objects are a subclass of ndarray, so they inherit all the attributes and methods of ndarrays.
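The practical differences show up in indexing and in what `*` means. A minimal sketch:

```python
import numpy as np

a = np.arange(6).reshape(2, 3)
m = np.matrix(a)            # legacy class; new code should use ndarray + @

assert a.ndim == 2 and m.ndim == 2
assert a[0].ndim == 1       # indexing an ndarray row yields a 1-D array
assert m[0].ndim == 2       # a matrix row stays 2-D, shape (1, 3)

# `*` is elementwise on ndarrays but matrix multiplication on matrices.
assert (a * a).shape == (2, 3)
assert (m * m.T).shape == (2, 2)
```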

Is NumPy matrix deprecated?

tl;dr: the numpy.matrix class is getting deprecated. Some high-profile libraries still depend on the class (the largest one being SciPy).


1 Answer

The Python version hands the whole multiplication to an optimized native BLAS routine, while the Clojure version builds a new intermediate lazy sequence of boxed numbers for each of the calls to map in this code. The performance difference you see most likely comes from this difference in data structures.
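The boxing cost is easy to see even from Python: a list of floats is a vector of pointers to heap-allocated objects, much like a lazy seq of boxed doubles on the JVM, while an ndarray is one flat unboxed buffer. A rough sketch of the footprint difference:

```python
import sys
import numpy as np

n = 100_000
as_list = [float(i) for i in range(n)]
as_array = np.arange(n, dtype=np.float64)

# The list stores n pointers plus n separate 24-byte float objects;
# the ndarray stores n unboxed 8-byte doubles contiguously.
list_bytes = sys.getsizeof(as_list) + sum(sys.getsizeof(x) for x in as_list)
array_bytes = as_array.nbytes

print(list_bytes, array_bytes)
```

On CPython the boxed representation is several times larger, and the pointer-chasing it forces is what kills cache locality in the inner loop.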

To do better than this, you could use a library like Incanter or write your own version as explained in this SO question (see also this one, Neanderthal, or nd4j). If you really want to stay with sequences to keep the lazy-evaluation properties, then you may get a real boost by looking into transients for the internal matrix calculations.

EDIT: I forgot to add the first step in tuning Clojure: turn on reflection warnings with `(set! *warn-on-reflection* true)`.

Answered Sep 17 '22 by Arthur Ulfeldt