 

Optimizing n-queens in Haskell

This code:

{-# LANGUAGE BangPatterns #-}


module Main where

import Data.Bits
import Data.Word
import Control.Monad
import System.CPUTime
import Data.List

-- The n-queens problem (German: Damenproblem).
-- Wiki: https://de.wikipedia.org/wiki/Damenproblem
main :: IO ()
main = do
  start <- getCPUTime
  print $ dame 14
  end <- getCPUTime
  print $ "Needed " ++ (show ((fromIntegral (end - start)) / (10^12))) ++ " Seconds"

type BitState = (Word64, Word64, Word64)

dame :: Int -> Int
dame max = foldl' (+) 0 $ map fn row
  where fn x = recur (max - 2) $ nextState (x, x, x)
        recur !depth !state = foldl' (+) 0 $ flip map row $ getPossible depth (getStateVal state) state
        getPossible depth !stateVal state bit
          | (bit .&. stateVal) > 0 = 0
          | depth == 0 = 1
          | otherwise = recur (depth - 1) (nextState (addBitToState bit state))
        row = take max $ iterate moveLeft 1

getStateVal :: BitState -> Word64
getStateVal (l, r, c) = l .|. r .|. c

addBitToState :: Word64 -> BitState -> BitState
addBitToState l (ol, or, oc) = (ol .|. l, or .|. l, oc .|. l)

nextState :: BitState -> BitState
nextState (l, r, c) = (moveLeft l, moveRight r, c)

moveRight :: Word64 -> Word64
moveRight x = shiftR x 1

moveLeft :: Word64 -> Word64
moveLeft x = shift x 1

needs about 60 seconds to execute without optimization. With -O2 it takes about 7 seconds, and -O1 is actually faster, at about 5 seconds. A Java version of the same code, using for-loops instead of mapped lists, takes about 1 s (!). I've been trying my hardest to optimize it, but none of the tips I found online saved more than half a second. Please help.

Edit: Java version:

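public class Queens {
    // Board size; not shown in the original snippet, assumed 14 to match the Haskell version.
    static final int N = 14;

    public static void main(String[] args) {
        System.out.println(getQueens());
    }

    static int getQueens(){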
        int res = 0;
        for (int i = 0; i < N; i++) {
            int pos = 1 << i;
            res += run(pos << 1, pos >> 1, pos, N - 2);
        }
        return res;
    }

    static int run(long diagR, long diagL, long mid, int depth) {
        long valid = mid | diagL | diagR;
        int resBuffer = 0;

        for (int i = 0; i < N; i++) {
            int pos = 1 << i;
            if ((valid & pos) > 0) {
                continue;
            }
            if (depth == 0) {
                resBuffer++;
                continue;
            }
            long n_mid = mid | pos;
            long n_diagL = (diagL >> 1) | (pos >> 1);
            long n_diagR = (diagR << 1) | (pos << 1);

            resBuffer += run(n_diagR, n_diagL, n_mid, depth - 1);
        }
        return resBuffer;
    }
}

Edit: Running on Windows with GHC 8.4.1 on an Intel Core i5-650 at 3.2 GHz.
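For comparison, the Java loop can be transliterated into Haskell almost line for line by replacing the mapped lists with explicit strict worker loops. This is a hypothetical sketch, not code from the question or the answer: the names queensLoop, run and go are illustrative, it assumes 2 <= n <= 32 so every shifted bit fits in a Word64, and it has not been benchmarked.

```haskell
{-# LANGUAGE BangPatterns #-}
-- Hypothetical sketch: a direct transliteration of the Java getQueens/run
-- methods, using strict accumulator loops instead of map + foldl'.
import Data.Bits (shiftL, shiftR, (.&.), (.|.))
import Data.Word (Word64)

-- | Count n-queens solutions (assumes 2 <= n <= 32, like the Java code).
queensLoop :: Int -> Int
queensLoop n = outer 0 0
  where
    -- Outer loop over the first queen's column (Java's getQueens).
    outer !i !acc
      | i >= n    = acc
      | otherwise =
          let pos = 1 `shiftL` i :: Word64
          in outer (i + 1) (acc + run (pos `shiftL` 1) (pos `shiftR` 1) pos (n - 2))

    -- Inner loop over candidate columns of one row (Java's run).
    run :: Word64 -> Word64 -> Word64 -> Int -> Int
    run !diagR !diagL !mid !depth = go 0 0
      where
        blocked = mid .|. diagL .|. diagR
        go !i !acc
          | i >= n               = acc
          | blocked .&. pos /= 0 = go (i + 1) acc        -- square attacked
          | depth == 0           = go (i + 1) (acc + 1)  -- last row: one solution
          | otherwise            =
              go (i + 1) (acc + run ((diagR .|. pos) `shiftL` 1)
                                    ((diagL .|. pos) `shiftR` 1)
                                    (mid .|. pos)
                                    (depth - 1))
          where pos = 1 `shiftL` i :: Word64

main :: IO ()
main = print (queensLoop 8)  -- 92 solutions on an 8x8 board
```

Note that (diagL | pos) >> 1 is the same as Java's (diagL >> 1) | (pos >> 1), and likewise for the left shift.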

asked Jan 21 '26 04:01 by steffmaster



1 Answer

Assuming your algorithm is correct (I haven't verified this), I was consistently able to get it down to about 900 ms, which is faster than the Java implementation! -O2 and -O3 performed comparably on my machine.
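One cheap way to check correctness is to compare the bit-twiddling code against a naive reference on small boards. The sketch below is a hypothetical helper (queensNaive is not from the question or the answer): a plain list-comprehension solver that is far too slow for n = 14 but easy to trust.

```haskell
-- Hypothetical reference implementation for cross-checking small n.
queensNaive :: Int -> Int
queensNaive n = length (placements n)
  where
    -- All ways to place k queens in the first k rows, one queen per row.
    -- The head of each list is the queen in the most recent row.
    placements :: Int -> [[Int]]
    placements 0 = [[]]
    placements k = [q : qs | qs <- placements (k - 1), q <- [1 .. n], safe q qs]

    -- A queen in column q is safe if no earlier queen shares its column or a
    -- diagonal; d is the row distance to each earlier queen.
    safe :: Int -> [Int] -> Bool
    safe q qs = and [q /= c && abs (q - c) /= d | (d, c) <- zip [1 ..] qs]

main :: IO ()
main = mapM_ (print . queensNaive) [4 .. 8]  -- prints 2, 10, 4, 40, 92
```

Checking that dame n == queensNaive n for, say, n from 4 to 10 would confirm the fast version on those sizes.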

Notable changes (EDIT: the most important change is switching from lists to vectors):

- Switched to GHC 8.4.1.
- Used strictness liberally.
- BitState is now a strict 3-tuple (a data type with three strict Word64 fields).
- Using vectors is important for better speed; in my opinion you can't achieve comparable speed with plain linked lists, even with fusion.
- The vectors are unboxed, which works because the elements are always Word64s or Ints; an unboxed vector stores the raw values contiguously instead of pointers to boxed values.

{-# LANGUAGE BangPatterns #-}

module Main (main) where

import Data.Bits ((.&.), (.|.), shiftR, shift)
import Data.Vector.Unboxed (Vector)
import qualified Data.Vector.Unboxed as Vector
import Data.Word (Word64)
import Prelude hiding (max, sum)
import System.CPUTime (getCPUTime)

-- The n-queens problem (German: Damenproblem).
-- Wiki: https://de.wikipedia.org/wiki/Damenproblem
main :: IO ()
main = do
  start <- getCPUTime
  print $ dame 14
  end <- getCPUTime
  print $ "Needed " ++ (show ((fromIntegral (end - start)) / (10^12))) ++ " Seconds"

data BitState = BitState !Word64 !Word64 !Word64

bmap :: (Word64 -> Word64) -> BitState -> BitState
bmap f (BitState x y z) = BitState (f x) (f y) (f z)
{-# INLINE bmap #-}

bfold :: (Word64 -> Word64 -> Word64) -> BitState -> Word64
bfold f (BitState x y z) = x `f` y `f` z 
{-# INLINE bfold #-}

singleton :: Word64 -> BitState
singleton !x = BitState x x x
{-# INLINE singleton #-}

dame :: Int -> Int
dame !x = sumWith fn row
  where
    fn !x' = recur (x - 2) $ nextState $ singleton x'
    getPossible !depth !stateVal !state !bit
      | (bit .&. stateVal) > 0 = 0
      | depth == 0 = 1
      | otherwise = recur (depth - 1) (nextState (addBitToState bit state))
    recur !depth !state = sumWith (getPossible depth (getStateVal state) state) row
    !row = Vector.iterateN x moveLeft 1

sumWith :: (Vector.Unbox a, Vector.Unbox b, Num b) => (a -> b) -> Vector a -> b
sumWith f as = Vector.sum $ Vector.map f as
{-# INLINE sumWith #-}

getStateVal :: BitState -> Word64
getStateVal !b = bfold (.|.) b

addBitToState :: Word64 -> BitState -> BitState
addBitToState !l !b = bmap (.|. l) b

nextState :: BitState -> BitState
nextState !(BitState l r c) = BitState (moveLeft l) (moveRight r) c

moveRight :: Word64 -> Word64
moveRight !x = shiftR x 1
{-# INLINE moveRight #-}

moveLeft :: Word64 -> Word64
moveLeft !x = shift x 1
{-# INLINE moveLeft #-}

I checked the Core with ghc dame.hs -O2 -fforce-recomp -ddump-simpl -dsuppress-all, and it looked pretty good (everything unboxed, and the loops looked tight). I was concerned that the partial application of getPossible might be a problem, but it turned out not to be. If I understood the algorithm better, it might be possible to write it in a better, more efficient way; however, I'm not too concerned, since this version already beats the Java implementation.
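As one example of a "more efficient way", a common bitboard technique (a swapped-in technique, not what this answer does) is to compute the set of free columns as a single mask and visit only those, peeling off the lowest set bit with x .&. negate x. The sketch below is hypothetical (queensFree, place and loop are illustrative names), assumes n <= 63, and has not been timed; any speedup over the vector version is something to measure.

```haskell
{-# LANGUAGE BangPatterns #-}
-- Hypothetical sketch: visit only free columns instead of testing all n.
import Data.Bits (complement, shiftL, shiftR, (.&.), (.|.))
import Data.Word (Word64)

queensFree :: Int -> Int
queensFree n = place 0 0 0 0
  where
    full = (1 `shiftL` n) - 1 :: Word64  -- one bit per column

    -- cols, diagL, diagR: squares attacked in the current row.
    place :: Int -> Word64 -> Word64 -> Word64 -> Int
    place !row !cols !diagL !diagR
      | row == n  = 1
      | otherwise = loop (full .&. complement (cols .|. diagL .|. diagR)) 0
      where
        -- Iterate over the set bits of 'free', one free column at a time.
        loop :: Word64 -> Int -> Int
        loop !free !acc
          | free == 0 = acc
          | otherwise =
              let bit = free .&. negate free  -- lowest set bit (Word64 wraps)
                  sub = place (row + 1)
                              (cols .|. bit)
                              ((diagL .|. bit) `shiftL` 1)
                              ((diagR .|. bit) `shiftR` 1)
              in loop (free .&. (free - 1)) (acc + sub)  -- clear that bit

main :: IO ()
main = print (queensFree 8)  -- 92
```

queensFree 14 should agree with dame 14 (365596 solutions); masking with full discards the diagonal bits that shift past the board edge.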

answered Jan 23 '26 19:01 by chessai



