
Code Review of Haskell PBKDF2

Submitted by: @import:stackexchange-codereview
Tags: codereview, pbkdf2, haskell

Problem

Moved from Programmers.SE.

I have written a new version of the PBKDF2 algorithm in Haskell. It passes all of the HMAC-SHA-1 test vectors listed in RFC 6070, but it is not very efficient. How can I improve the code? I plan to upload it to GitHub (and maybe Hackage) once it is improved.
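For reference, this is the PBKDF2 definition from RFC 2898 (restated in RFC 8018) that the code implements. Here hLen is the PRF output length (20 bytes for HMAC-SHA-1), c is the iteration count, dkLen is the derived key length, and INT(i) is the four-octet big-endian encoding of the block index i. In the code, `totalBlocks` is l, each call to `blockIterator` computes one T_i starting from an all-zero accumulator, and `intToOctets` produces INT(i):

```latex
\begin{aligned}
l &= \left\lceil \frac{dkLen}{hLen} \right\rceil \\
U_1 &= PRF(P,\; S \,\|\, INT(i)), \qquad U_j = PRF(P,\; U_{j-1}) \quad (2 \le j \le c) \\
T_i &= U_1 \oplus U_2 \oplus \cdots \oplus U_c, \qquad i = 1, \dots, l \\
DK &= \text{the first } dkLen \text{ octets of } T_1 \,\|\, T_2 \,\|\, \cdots \,\|\, T_l
\end{aligned}
```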

{-# LANGUAGE BangPatterns #-}
{- Copyright 2013, G. Ralph Kuntz, MD. All rights reserved. LGPL License. -}

module Crypto where

import Codec.Utils (Octet)
import qualified Data.Binary as B (encode)
import Data.Bits (xor)
import qualified Data.ByteString.Lazy.Char8 as C (pack)
import qualified Data.ByteString.Lazy as L (unpack)
import Data.List (foldl')
import Data.HMAC (hmac_sha1)
import Text.Bytedump (dumpRaw)

-- Calculate the PBKDF2 as a hexadecimal string
pbkdf2
    :: ([Octet] -> [Octet] -> [Octet]) -- pseudo random function (HMAC)
    -> Int -- hash length in bytes
    -> String -- password
    -> String -- salt
    -> Int -- iterations
    -> Int -- derived key length in bytes
    -> String
pbkdf2 prf hashLength password salt iterations keyLength =
    let
        passwordOctets = stringToOctets password
        saltOctets = stringToOctets salt
        totalBlocks =
            ceiling $ (fromIntegral keyLength :: Double) / fromIntegral hashLength
        blockIterator message acc =
            foldl' (\(a, m) _ ->
                let !m' = prf passwordOctets m
                in (zipWith xor a m', m')) (acc, message) [1..iterations]
    in
        dumpRaw $ take keyLength $ foldl' (\acc block ->
            acc ++ fst (blockIterator (saltOctets ++ intToOctets block)
                (replicate hashLength 0))) [] [1..totalBlocks]
  where
    intToOctets :: Int -> [Octet]
    intToOctets i =
        let a = L.unpack . B.encode $ i
        in drop (length a - 4) a

    stringToOctets :: String -> [Octet]
    stringToOctets = L.unpack . C.pack

-- Calculate the PBKDF2 as a hexadecimal string using HMAC and SHA-1
pbkdf2HmacSha1
    :: String -- password
    -> String -- salt
    -> Int -- iterations
    -> Int -- derived key length in bytes
    -> String
pbkdf2HmacSha1 = pbkdf2 hmac_sha1 20 -- SHA-1 digests are 20 bytes
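As a sketch using only base, the two small helper computations in `pbkdf2` can also be done without the round trip through `Double` and `Data.Binary`. `blockCount` and `int32BE` are hypothetical names, and the sketch uses `Word8` directly (Crypto's `Octet` is a synonym for it):

```haskell
import Data.Bits (shiftR, (.&.))
import Data.Word (Word8)

-- Number of hashLength-sized blocks needed for keyLength bytes:
-- ceiling (keyLength / hashLength) in integer arithmetic, no Double.
blockCount :: Int -> Int -> Int
blockCount hashLength keyLength = (keyLength + hashLength - 1) `div` hashLength

-- INT(i) from RFC 2898: i as four octets, most significant first.
int32BE :: Int -> [Word8]
int32BE i = [fromIntegral ((i `shiftR` s) .&. 0xff) | s <- [24, 16, 8, 0]]
```

For example, `blockCount 20 25` is 2 and `int32BE 258` is `[0,0,1,2]`.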

Solution

The comments on SO give some good hints. The performance-critical part is clearly blockIterator, so let's make it a separate function:

blockIterator
    :: ([Octet] -> [Octet] -> [Octet])  -- ^ pseudo random function (HMAC)
    -> [Octet]                          -- ^ password octets
    -> Int                              -- ^ iterations
    -> [Octet]                          -- ^ m: initial message (salt ++ INT(block))
    -> [Octet]                          -- ^ a: initial accumulator (hashLength zero octets)
    -> [Octet]


The implementation as a fold over a list is nice and concise, but building and traversing the list [1..iterations] adds unnecessary overhead and can prevent some optimizations. If we instead rewrite it as a recursive function

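-- requires: import Control.DeepSeq (deepseq)
blockIterator prf passwordOctets = loop
  where
    -- i: remaining rounds, m: previous PRF output, a: running XOR
    loop i m a | i == 0     = a
               | otherwise  = let m' = prf passwordOctets m
                                  a' = zipWith xor a m'
                               in a' `deepseq`
                                  m' `deepseq`
                                  loop (i - 1) m' a'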


it gets somewhat faster. I also applied deepseq to both octet lists so that they are fully evaluated in each round instead of building up unevaluated thunks. Putting the recursion in the local loop, instead of making blockIterator itself recursive, allows GHC to inline blockIterator (GHC does not inline self-recursive functions).
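To check that the rewrite does not change the result, here is a self-contained sketch that puts the question's fold and the answer's loop side by side. `toyPrf` is a hypothetical stand-in for `hmac_sha1` (a deterministic mixing function, not a real PRF) so that the sketch needs only base and deepseq; `blockIteratorFold` and `blockIteratorLoop` are illustrative names:

```haskell
import Control.DeepSeq (deepseq)
import Data.Bits (xor)
import Data.List (foldl')
import Data.Word (Word8)

-- Codec.Utils defines Octet as Word8; redefined here to avoid the dependency.
type Octet = Word8

-- Hypothetical stand-in for hmac_sha1: deterministic, always 4 octets,
-- and cryptographically worthless. It only exists so both loops can run.
toyPrf :: [Octet] -> [Octet] -> [Octet]
toyPrf key msg = [fromIntegral (h * 31 + j) | j <- [1 .. 4]]
  where
    h = foldl' (\acc o -> acc * 131 + fromIntegral o) 7 (key ++ msg) :: Int

-- The question's version: fold over [1..iterations], carrying (a, m).
blockIteratorFold
    :: ([Octet] -> [Octet] -> [Octet]) -> [Octet] -> Int -> [Octet] -> [Octet] -> [Octet]
blockIteratorFold prf passwordOctets iterations message acc =
    fst (foldl' step (acc, message) [1 .. iterations])
  where
    step (a, m) _ = let m' = prf passwordOctets m in (zipWith xor a m', m')

-- The answer's version: explicit local loop, both lists forced each round.
blockIteratorLoop
    :: ([Octet] -> [Octet] -> [Octet]) -> [Octet] -> Int -> [Octet] -> [Octet] -> [Octet]
blockIteratorLoop prf passwordOctets = loop
  where
    loop i m a
      | i == 0 = a
      | otherwise =
          let m' = prf passwordOctets m
              a' = zipWith xor a m'
           in a' `deepseq` (m' `deepseq` loop (i - 1) m' a')
```

In GHCi, `blockIteratorFold toyPrf [1] n [2, 3] [0, 0, 0, 0] == blockIteratorLoop toyPrf [1] n [2, 3] [0, 0, 0, 0]` should hold for any n >= 0.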

Measuring performance can be done using the criterion package:

import Criterion.Main

-- ...

main :: IO ()
main = do
    let stdtest n = pbkdf2HmacSha1 "password" "salt" n 20
    defaultMain [ bgroup "stdtest" $
                    map (\n -> bench (show n) (nf stdtest n))
                        [1, 2, 4096, 65536]
                ]


However, by far the most time-consuming part is still the list processing inside blockIterator. Using unboxed ST arrays would make the code much faster: there would be no memory allocation in the loop and no need to force evaluation. The problem is that hmac_sha1 is implemented using lists, so we would need a separate, optimized ST-based implementation of it.
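A minimal sketch of the array idea, assuming each block is a `UArray Int Word8` and that the PRF has already been applied to the password (`prf :: Block -> Block` is a hypothetical interface; `hmac_sha1` does not have this type). Only the XOR accumulator lives in a mutable `STUArray` and is updated in place; the PRF still allocates a fresh array every round, which is exactly the part that would need an ST-based HMAC:

```haskell
import Control.Monad (forM_)
import Control.Monad.ST (ST)
import Data.Array.ST (STUArray, newArray, readArray, runSTUArray, writeArray)
import Data.Array.Unboxed (UArray, bounds, elems, listArray, (!))
import Data.Bits (xor)
import Data.Word (Word8)

-- One PRF output block (e.g. 20 octets for HMAC-SHA-1), indexed from 0.
type Block = UArray Int Word8

-- XOR the block u into the mutable accumulator, element by element, in place.
xorInto :: STUArray s Int Word8 -> Block -> ST s ()
xorInto acc u = forM_ [lo .. hi] $ \k -> do
    x <- readArray acc k
    writeArray acc k (x `xor` (u ! k))
  where
    (lo, hi) = bounds u

-- T_i = U_1 xor ... xor U_c for one block, starting from message = S || INT(i).
-- `prf` is a hypothetical PRF already applied to the password.
blockST :: (Block -> Block) -> Int -> Block -> Block
blockST prf iterations message = runSTUArray $ do
    let u1 = prf message
    acc <- newArray (bounds u1) 0            -- all-zero accumulator
    let go n u
          | n <= 0 = return ()
          | otherwise = xorInto acc u >> go (n - 1) (prf u)
    go iterations u1
    return acc
```

With some block-to-block function `toy`, `elems (blockST toy 3 (listArray (0, 3) [1, 2, 3, 4]))` gives the XOR of the first three PRF outputs as a list.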


Context

StackExchange Code Review Q#30733, answer score: 3
