 

SHA1 in Haskell -- something wrong with my implementation

Tags:

haskell

sha1

Thought I'd try to implement SHA1 in Haskell myself. I came up with an implementation that compiles and returns the right answer for the empty string (""), but for nothing else. I can't figure out what might be wrong. Can someone familiar with SHA1 point it out?

import Data.Bits
import Data.Int
import Data.List
import Data.Word
import Text.Printf
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Lazy.Char8 as C

h0 = 0x67452301 :: Word32
h1 = 0xEFCDAB89 :: Word32
h2 = 0x98BADCFE :: Word32
h3 = 0x10325476 :: Word32
h4 = 0xC3D2E1F0 :: Word32

sha1string :: String -> String
sha1string s = concat $ map (printf "%02x") $ sha1 . C.pack $ s 

sha1 :: L.ByteString -> [Word8]
sha1 msg = concat [w32ToComps a, w32ToComps b, w32ToComps c, w32ToComps d, w32ToComps e]
    where (a, b, c, d, e) = sha1' msg 0 h0 h1 h2 h3 h4 

sha1' msg sz a b c d e 
    | L.length m1 < 64 = sha1'last (padded msg sz) a b c d e
    | otherwise        = uncurry5 (sha1' m2 (sz + 64)) $ whole a b c d e m1
    where (m1, m2) = L.splitAt 64 msg

sha1'last msg a b c d e
    | m1 == L.empty = (a, b, c, d, e)
    | otherwise     = uncurry5 (sha1'last m2) $ whole a b c d e m1
    where (m1, m2) = L.splitAt 64 msg

whole a b c d e msg = partcd (partab msg) a b c d e 

partcd ws a b c d e = (h0 + a', h1 + b', h2 + c', h3 + d', h4 + e')
    where
    (a', b', c', d', e')  = go ws a b c d e 0
    go ws a b c d e 80    = (a, b, c, d, e)
    go (w:ws) a b c d e t = go ws temp a (rotate b 30) c d (t+1)
        where temp = (rotate a 5) + f t b c d + e + w + k t

partab chunk = take 80 ns
    where
    ns        = initial ++ zipWith4 g (drop 13 ns) (drop 8 ns) (drop 2 ns) ns
    g a b c d = rotate (a `xor` b `xor` c `xor` d) 1
    initial   = map (L.foldl (\a b -> (a * 256) + fromIntegral b) 0) $ paginate 4 chunk

f t b c d
    | t >=  0 && t <= 19 = (b .&. c) .|. ((complement b) .&. d)
    | t >= 20 && t <= 39 = b `xor` c `xor` d
    | t >= 40 && t <= 59 = (b .&. c) .|. (b .&. d) .|. (c .&. d)
    | t >= 60 && t <= 79 = b `xor` c `xor` d

k t
    | t >=  0 && t <= 19 = 0x5A827999
    | t >= 20 && t <= 39 = 0x6ED9EBA1
    | t >= 40 && t <= 59 = 0x8F1BBCDC
    | t >= 60 && t <= 79 = 0xCA62C1D6

padded msg prevsz = L.append msg (L.pack pad)
    where
    sz      = L.length msg
    totalsz = prevsz + sz
    padsz   = fromIntegral $ (128 - 9 - sz) `mod` 64
    pad     = [0x80] ++ (replicate padsz 0) ++ int64ToComps totalsz

uncurry5 f (a, b, c, d, e) = f a b c d e

paginate n xs
    | xs == L.empty = []
    | otherwise     = let (a, b) = L.splitAt n xs in a : paginate n b

w32ToComps :: Word32 -> [Word8]
w32ToComps = integerToComps [24, 16 .. 0] 

int64ToComps :: Int64 -> [Word8]
int64ToComps = integerToComps [56, 48 .. 0] 

integerToComps :: (Integral a, Bits a) => [Int] -> a -> [Word8]
integerToComps bits x = map f bits
    where f n = fromIntegral ((x `shiftR` n) .&. 0xff) :: Word8
asked Nov 12 '11 at 01:11 by Ana



1 Answer

For starters, you appear to be keeping the size count in bytes (see sz + 64), but the message length appended during padding must be in bits, so you need to multiply by 8 somewhere. (Incidentally, I suggest you use cereal or binary instead of rolling your own Integer-to-big-endian Word64 conversion.) This isn't the only problem, though.
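As a sketch of both points, assuming the binary package's Data.Binary.Put (putWord64be, runPut) is available, the padding can be built with the length field serialized by the library instead of by hand, and with the byte count converted to bits. sha1Padding is a hypothetical helper name, not from the question; it returns only the bytes to append after a message of the given total byte length.

```haskell
import qualified Data.ByteString.Lazy as L
import Data.Binary.Put (putWord64be, runPut)
import Data.Int (Int64)

-- Bytes to append to a message whose total length is totalBytes bytes.
-- SHA-1 appends the message length in BITS as a 64-bit big-endian
-- integer, so the byte count is multiplied by 8 before serializing.
sha1Padding :: Int64 -> L.ByteString
sha1Padding totalBytes = L.concat [L.singleton 0x80, L.replicate zeros 0, lenField]
  where
    -- choose zeros so that totalBytes + 1 + zeros + 8 is a multiple of 64
    zeros    = (55 - totalBytes) `mod` 64
    -- binary's Put handles the big-endian encoding
    lenField = runPut (putWord64be (fromIntegral totalBytes * 8))
```

The zero count (55 - n) `mod` 64 equals the question's (128 - 9 - sz) `mod` 64, so in the question's padded the smallest change is probably just passing totalsz * 8 to int64ToComps.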

EDIT: Found It

Ah-ha! Never forget, Wikipedia is written by a bunch of imperative, mutable-world unenlighteneds! You finish each 64-byte chunk (in partcd) with h0 + a', h1 + b', ..., but that should be the old context plus your new values: a + a', b + b', .... Everything checks out after that fix (and the size fix above).
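A sketch of the per-block step with that fix, written as a standalone compression function over 16 big-endian words (compress, schedule and sha1Init are illustrative names, not from the question). For a one-block message the incoming state is exactly h0..h4, so h0 + a' and a + a' agree; together with the length being 0 either way, that is likely why the empty string came out right while other inputs did not.

```haskell
import Data.Bits (complement, rotateL, xor, (.&.), (.|.))
import Data.List (foldl', zipWith4)
import Data.Word (Word32)

type State = (Word32, Word32, Word32, Word32, Word32)

-- Standard SHA-1 initial chaining values (h0..h4 in the question).
sha1Init :: State
sha1Init = (0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0)

-- Compress one 512-bit block (16 big-endian words) into the running state.
-- The 80-round result is added to the INCOMING state (a..e), which only
-- coincides with h0..h4 for the first block.
compress :: State -> [Word32] -> State
compress st@(a, b, c, d, e) block = (a + a', b + b', c + c', d + d', e + e')
  where
    (a', b', c', d', e') = foldl' step st (zip [0 ..] (schedule block))
    step (va, vb, vc, vd, ve) (t, w) =
      let temp = rotateL va 5 + f t vb vc vd + ve + w + k t
      in (temp, va, rotateL vb 30, vc, vd)

-- Expand 16 words to 80: w[t] = rotl 1 (w[t-3] xor w[t-8] xor w[t-14] xor w[t-16]).
schedule :: [Word32] -> [Word32]
schedule block = take 80 ws
  where
    ws = block ++ zipWith4 mix (drop 13 ws) (drop 8 ws) (drop 2 ws) ws
    mix p q r s = rotateL (p `xor` q `xor` r `xor` s) 1

-- Round function for step t.
f :: Int -> Word32 -> Word32 -> Word32 -> Word32
f t b c d
  | t < 20    = (b .&. c) .|. (complement b .&. d)
  | t < 40    = b `xor` c `xor` d
  | t < 60    = (b .&. c) .|. (b .&. d) .|. (c .&. d)
  | otherwise = b `xor` c `xor` d

-- Round constant for step t.
k :: Int -> Word32
k t
  | t < 20    = 0x5A827999
  | t < 40    = 0x6ED9EBA1
  | t < 60    = 0x8F1BBCDC
  | otherwise = 0xCA62C1D6
```

Folding compress over consecutive blocks of the padded message, starting from sha1Init, gives the digest; for the single padded block of "abc" it should produce the standard a9993e36 4706816a ba3e2571 7850c26c 9cd0d89d.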

The test code now completes with 5 property tests and 129 KATs (known-answer tests) succeeding.

End Edit

It would help you a lot to divide your implementation into the usual initialize, update, and finalize operations. That way you could compare intermediate results against other implementations.
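One way that split might look, as a sketch: the context carries the chaining state, a buffer of not-yet-compressed bytes, and the total length, while the per-block compression is passed in as a parameter. BlockFn, Ctx, initCtx, updateCtx and finalizeCtx are hypothetical names, not from the question or any library. Inspecting ctxState after each updateCtx gives intermediate values to compare against another implementation.

```haskell
import Data.Bits (shiftR)
import Data.Int (Int64)
import Data.Word (Word32, Word64, Word8)

type State = (Word32, Word32, Word32, Word32, Word32)

-- Hypothetical hook: compress one 64-byte block into the chaining state.
-- A real SHA-1 compression function would be plugged in here.
type BlockFn = State -> [Word8] -> State

-- Running context: chaining state, buffered tail (< 64 bytes), total bytes fed.
data Ctx = Ctx
  { ctxState  :: State
  , ctxBuffer :: [Word8]
  , ctxLength :: Int64
  } deriving (Eq, Show)

-- initialize: standard SHA-1 chaining values, nothing buffered yet.
initCtx :: Ctx
initCtx = Ctx (0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0) [] 0

-- update: compress every complete 64-byte block, keep the remainder buffered.
updateCtx :: BlockFn -> Ctx -> [Word8] -> Ctx
updateCtx compress (Ctx st buf n) bytes = go st (buf ++ bytes)
  where
    n' = n + fromIntegral (length bytes)
    go s xs = case splitAt 64 xs of
      (blk, rest) | length blk == 64 -> go (compress s blk) rest
      _                              -> Ctx s xs n'

-- finalize: append 0x80, zeros and the 64-bit big-endian BIT length, then
-- flush; buffer plus padding is always a whole number of 64-byte blocks.
finalizeCtx :: BlockFn -> Ctx -> State
finalizeCtx compress ctx = ctxState (updateCtx compress ctx pad)
  where
    n      = ctxLength ctx
    zeros  = fromIntegral ((55 - n) `mod` 64)
    bitLen = fromIntegral n * 8 :: Word64
    pad    = 0x80 : replicate zeros 0 ++ [fromIntegral (bitLen `shiftR` s) | s <- [56, 48 .. 0]]
```

Passing the block function in keeps the buffering and padding logic testable on its own, for example with a dummy function that only counts blocks. Lists of Word8 keep the sketch short; a real version would more likely buffer strict ByteStrings.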

I built test code for your implementation using crypto-api-tests. The additional code is below if you're interested; don't forget to install crypto-api-tests.

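-- Test harness for the question's module: add these definitions to it and
-- merge the imports with the ones at the top. Needs the crypto-api-tests,
-- crypto-api, cereal and tagged packages; older GHCs may also need
-- {-# LANGUAGE MultiParamTypeClasses #-} for the Hash instance.
import Test.SHA
import Test.Crypto
import Crypto.Classes
import Data.Serialize
import Data.Tagged
import Control.Monad

-- Build crypto-api-tests' SHA-1 tests for our digest type and run them.
main = defaultMain =<< makeSHA1Tests (undefined :: SHA1)

-- Digest wrapper: the 20 output bytes produced by sha1.
data SHA1 = SHA1 [Word8]
  deriving (Eq, Ord, Show)

-- Context: simply accumulate all input and hash it in finalize.
data CTX = CTX L.ByteString

instance Serialize SHA1 where
  get = liftM SHA1 (mapM (const get) [1..20])
  put (SHA1 x) = mapM_ put x

instance Hash CTX SHA1 where
  outputLength = Tagged 160
  blockLength  = Tagged (64*8)
  initialCtx   = CTX L.empty
  updateCtx   (CTX m) x = CTX (L.append m (L.fromChunks [x]))
  finalize  (CTX m) b = SHA1 $ sha1 (L.append m (L.fromChunks [b]))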
answered Nov 15 '22 at 06:11 by Thomas M. DuBuisson
