cryptonite/Crypto/Internal/ByteArray.hs
Vincent Hanquez ec4e0c4ed9 remove all the byteArray prefix from byteArray function.
instead expect module import to be qualified for functions.
2015-04-24 06:54:33 +01:00

253 lines
7.6 KiB
Haskell

-- |
-- Module : Crypto.Internal.ByteArray
-- License : BSD-style
-- Maintainer : Vincent Hanquez <vincent@snarc.org>
-- Stability : stable
-- Portability : Good
--
-- Simple and efficient byte array types
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE NoImplicitPrelude #-}
module Crypto.Internal.ByteArray
(
ByteArray(..)
, ByteArrayAccess(..)
-- * Inhabitants
, Bytes
, SecureBytes
-- * methods
, alloc
, allocAndFreeze
, empty
, zero
, copy
, convert
, copyRet
, copyAndFreeze
, split
, xor
, eq
, index
, constEq
, concat
, toBS
, fromBS
, toW64BE
, toW64LE
, mapAsWord64
, mapAsWord128
) where
import Data.SecureMem
import Crypto.Internal.Memory
import Crypto.Internal.Compat
import Crypto.Internal.Endian
import Crypto.Internal.Bytes (bufXor, bufCopy, bufSet)
import Crypto.Internal.Words
import Crypto.Internal.Imports hiding (empty)
import Foreign.Ptr
import Foreign.Storable
import Foreign.ForeignPtr
import Data.ByteString (ByteString)
import qualified Data.ByteString as B (length)
import qualified Data.ByteString.Internal as B
import Prelude (flip, return, div, (-), ($), (==), (/=), (<=), (>=), Int, Bool(..), IO, otherwise, sum, map, fmap, snd, (.), min)
-- | Types whose bytes can be inspected through a raw pointer.
class ByteArrayAccess ba where
    -- | Size of the array; the rest of this module uses it as a byte count
    -- when copying or comparing.
    length :: ba -> Int
    -- | Run an 'IO' action with a pointer to the array's data and return
    -- the action's result.
    withByteArray :: ba -> (Ptr p -> IO a) -> IO a
-- | Byte arrays that can also be allocated and initialised.
class ByteArrayAccess ba => ByteArray ba where
    -- | Takes a size and an initialiser that receives a pointer to the new
    -- storage; yields the initialiser's result paired with the new array.
    allocRet :: Int -> (Ptr p -> IO a) -> IO (a, ba)
-- | Allocate a byte array of @n@ bytes and initialise it with @f@.
--
-- This is 'allocRet' with the initialiser's unit result discarded.
alloc :: ByteArray ba => Int -> (Ptr p -> IO ()) -> IO ba
alloc n f = do
    pair <- allocRet n f
    return (snd pair)
-- | 'Bytes' delegates directly to the helpers 'bytesLength' and 'withBytes'.
instance ByteArrayAccess Bytes where
    length bytes = bytesLength bytes
    withByteArray bytes action = withBytes bytes action
-- | Allocation of 'Bytes' is handled entirely by 'bytesAllocRet'.
instance ByteArray Bytes where
    allocRet size fill = bytesAllocRet size fill
-- | A 'ByteString' is accessed through its underlying foreign pointer,
-- advanced by the string's starting offset.
instance ByteArrayAccess ByteString where
    length bs = B.length bs
    withByteArray bs action =
        let (fptr, off, _) = B.toForeignPtr bs
         in withForeignPtr fptr (\base -> action (base `plusPtr` off))
-- | A fresh 'ByteString' is built by allocating @size@ bytes, letting the
-- initialiser fill them, and wrapping the buffer at offset 0.
instance ByteArray ByteString where
    allocRet size fill = do
        buffer <- B.mallocByteString size
        result <- withForeignPtr buffer $ \raw -> fill (castPtr raw)
        return (result, B.PS buffer 0 size)
-- | 'SecureMem' exposes its size via 'secureMemGetSize' and its contents
-- via 'withSecureMemPtr' (the pointer is cast to the requested type).
instance ByteArrayAccess SecureMem where
    length mem = secureMemGetSize mem
    withByteArray mem action =
        withSecureMemPtr mem $ \raw -> action (castPtr raw)
-- | A new 'SecureMem' is obtained from 'allocateSecureMem' and then filled
-- in place by the initialiser.
instance ByteArray SecureMem where
    allocRet size fill = do
        mem <- allocateSecureMem size
        result <- withSecureMemPtr mem $ \raw -> fill (castPtr raw)
        return (result, mem)
-- | Pure variant of 'alloc': allocate @size@ bytes, run the initialiser on
-- them, and return the resulting array as a pure value (the 'IO' action is
-- run through 'unsafeDoIO').
allocAndFreeze :: ByteArray a => Int -> (Ptr p -> IO ()) -> a
allocAndFreeze size fill = unsafeDoIO $ alloc size fill
-- | The byte array of length 0; the initialiser does nothing.
empty :: ByteArray a => a
empty = allocAndFreeze 0 (\_ -> return ())
-- | Byte-wise XOR of two byte arrays.
--
-- The returned byte array has the length of the shorter input; any extra
-- bytes of the longer input are ignored.
xor :: (ByteArrayAccess a, ByteArrayAccess b, ByteArray c) => a -> b -> c
xor a b = allocAndFreeze outLen fill
  where
    outLen = min (length a) (length b)

    -- Hand the destination and both sources to 'bufXor'.
    fill dst =
        withByteArray a $ \srcA ->
        withByteArray b $ \srcB ->
            bufXor dst srcA srcB outLen
-- | Read the byte at offset @i@ of the array.
--
-- No bounds check is performed: the offset is added to the base pointer
-- and read as-is.
index :: ByteArrayAccess a => a -> Int -> Word8
index b i = unsafeDoIO (withByteArray b readAt)
  where
    readAt base = peek (base `plusPtr` i)
-- | Split a byte array at offset @n@ into a prefix of @n@ bytes and the
-- remaining suffix.
--
-- When @n <= 0@ the prefix is 'empty'; when @n@ is at least the length of
-- the array the suffix is 'empty'. In both of these cases the original
-- array is returned as the other half without copying.
split :: ByteArray bs => Int -> bs -> (bs, bs)
split n bs
    | n <= 0    = (empty, bs)
    | n >= len  = (bs, empty)
    | otherwise = unsafeDoIO $ withByteArray bs $ \src -> do
        front <- alloc n (\dst -> bufCopy dst src n)
        back  <- alloc rest (\dst -> bufCopy dst (src `plusPtr` n) rest)
        return (front, back)
  where
    len  = length bs
    rest = len - n
-- | Concatenate a list of byte arrays into a single array whose length is
-- the sum of the individual lengths.
concat :: ByteArray bs => [bs] -> bs
concat [] = empty
concat chunks = allocAndFreeze (sum (map length chunks)) (\dst -> go dst chunks)
  where
    -- Copy each chunk to the current write position, then advance that
    -- position by the chunk's length.
    go _ [] = return ()
    go out (c:cs) = do
        withByteArray c (\src -> bufCopy out src n)
        go (out `plusPtr` n) cs
      where
        n = length c
-- | Duplicate @bs@ into a newly allocated array of the same length, then
-- run @f@ on a pointer to the copy so it can be modified in place.
copy :: (ByteArrayAccess bs1, ByteArray bs2) => bs1 -> (Ptr p -> IO ()) -> IO bs2
copy bs f = alloc len populate
  where
    len = length bs

    populate dst = do
        withByteArray bs (\src -> bufCopy dst src len)
        f (castPtr dst)
-- | Like 'copy', but the action run on the copy may produce a value, which
-- is returned together with the new array.
copyRet :: (ByteArrayAccess bs1, ByteArray bs2) => bs1 -> (Ptr p -> IO a) -> IO (a, bs2)
copyRet bs f = allocRet len populate
  where
    len = length bs

    populate dst = do
        withByteArray bs (\src -> bufCopy dst src len)
        f (castPtr dst)
-- | Pure variant of 'copy': duplicate @bs@, run @f@ on the copy, and return
-- the result as a pure value (built with 'allocAndFreeze').
copyAndFreeze :: (ByteArrayAccess bs1, ByteArray bs2) => bs1 -> (Ptr p -> IO ()) -> bs2
copyAndFreeze bs f = allocAndFreeze len populate
  where
    len = length bs

    populate dst = do
        withByteArray bs (\src -> bufCopy dst src len)
        f (castPtr dst)
-- | A byte array of @n@ bytes, every one of them set to 0.
zero :: ByteArray ba => Int -> ba
zero n = allocAndFreeze n clear
  where
    clear dst = bufSet dst 0 n
-- | Equality test between two byte arrays, possibly of different types.
--
-- Arrays of different lengths are never equal. Otherwise the bytes are
-- compared front to back and the comparison stops at the first mismatch.
eq :: (ByteArrayAccess bs1, ByteArrayAccess bs2) => bs1 -> bs2 -> Bool
eq b1 b2
    | len1 == len2 = unsafeDoIO $
        withByteArray b1 $ \p1 ->
        withByteArray b2 $ \p2 ->
            go len1 p1 p2
    | otherwise = False
  where
    len1 = length b1
    len2 = length b2

    -- The first argument counts the bytes still to be compared.
    go :: Int -> Ptr Word8 -> Ptr Word8 -> IO Bool
    go 0 _ _ = return True
    go remaining pa pb = do
        x <- peek pa
        y <- peek pb
        if x == y
            then go (remaining - 1) (pa `plusPtr` 1) (pb `plusPtr` 1)
            else return False
-- | An equality test for two 'ByteArrayAccess' values that is intended to
-- run in constant time with respect to the contents of the arrays.
--
-- If the values are of two different sizes, the function returns 'False'
-- immediately, without comparing any bytes.
--
-- Compared to 'eq', this function goes over all the bytes present before
-- yielding a result, even when the overall result is already known early
-- in the processing.
constEq :: (ByteArrayAccess bs1, ByteArrayAccess bs2) => bs1 -> bs2 -> Bool
constEq b1 b2
    | l1 /= l2 = False
    | otherwise = unsafeDoIO $
        withByteArray b1 $ \p1 ->
        withByteArray b2 $ \p2 ->
            loop l1 True p1 p2
  where
    l1 = length b1
    l2 = length b2
    -- Walk both arrays in lock step. The first argument counts the bytes
    -- left to compare; the strict accumulator records whether every pair
    -- seen so far matched. There is no early exit on a mismatch.
    loop :: Int -> Bool -> Ptr Word8 -> Ptr Word8 -> IO Bool
    loop 0 !ret _ _ = return ret
    loop i !ret p1 p2 = do
        e <- (==) <$> peek p1 <*> peek p2
        loop (i-1) (ret &&! e) (p1 `plusPtr` 1) (p2 `plusPtr` 1)
    -- Logical AND spelled out as a full truth table, so that every
    -- equation pattern-matches on both operands.
    -- NOTE(review): presumably written this way to avoid a
    -- short-circuiting conjunction; confirm before simplifying.
    (&&!) :: Bool -> Bool -> Bool
    True &&! True = True
    True &&! False = False
    False &&! True = False
    False &&! False = False
-- | Copy the contents of a byte array into a new 'ByteString'.
toBS :: ByteArray bs => bs -> ByteString
toBS bs = copyAndFreeze bs untouched
  where
    -- The copy is returned as-is; nothing is modified afterwards.
    untouched _ = return ()
-- | Copy the contents of a 'ByteString' into a new byte array of the
-- requested type.
fromBS :: ByteArray bs => ByteString -> bs
fromBS bs = copyAndFreeze bs untouched
  where
    -- The copy is returned as-is; nothing is modified afterwards.
    untouched _ = return ()
-- | Read a 64-bit word stored at byte offset @ofs@ of the array, decoding
-- it with 'fromBE64'.
--
-- No bounds check is performed on @ofs@.
toW64BE :: ByteArrayAccess bs => bs -> Int -> Word64
toW64BE bs ofs = unsafeDoIO $ withByteArray bs $ \base -> do
    raw <- peek (base `plusPtr` ofs)
    return (fromBE64 raw)
-- | Read a 64-bit word stored at byte offset @ofs@ of the array, decoding
-- it with 'fromLE64'.
--
-- No bounds check is performed on @ofs@.
toW64LE :: ByteArrayAccess bs => bs -> Int -> Word64
toW64LE bs ofs = unsafeDoIO $ withByteArray bs $ \base -> do
    raw <- peek (base `plusPtr` ofs)
    return (fromLE64 raw)
-- | Map a function over a byte array viewed as consecutive 128-bit blocks.
--
-- Each complete 16-byte block is read as two 64-bit words decoded with
-- 'fromBE64' (bytes 0-7 become the first field of 'Word128', bytes 8-15
-- the second), passed through @f@, and written to the same position of
-- the output using 'toBE64'.
--
-- The result always has the same length as the input. When that length is
-- not a multiple of 16, the trailing bytes that do not form a complete
-- block are copied to the output unchanged, so that no part of the result
-- is left uninitialised.
mapAsWord128 :: ByteArray bs => (Word128 -> Word128) -> bs -> bs
mapAsWord128 f bs =
    allocAndFreeze len $ \dst ->
    withByteArray bs $ \src ->
        loop len dst src
  where
    len = length bs

    -- The first argument is the number of bytes not yet written.
    loop remaining d s
        | remaining >= 16 = do
            w1 <- peek s
            w2 <- peek (s `plusPtr` 8)
            let (Word128 r1 r2) = f (Word128 (fromBE64 w1) (fromBE64 w2))
            poke d (toBE64 r1)
            poke (d `plusPtr` 8) (toBE64 r2)
            loop (remaining - 16) (d `plusPtr` 16) (s `plusPtr` 16)
        | remaining <= 0 = return ()
        -- Partial final block: too short to map, so copy it verbatim.
        | otherwise = bufCopy (castPtr d) (castPtr s) remaining
-- | Map a function over a byte array viewed as consecutive 64-bit words.
--
-- Each complete 8-byte block is decoded with 'fromBE64', passed through
-- @f@, and written to the same position of the output using 'toBE64'.
--
-- The result always has the same length as the input. When that length is
-- not a multiple of 8, the trailing bytes that do not form a complete
-- word are copied to the output unchanged, so that no part of the result
-- is left uninitialised.
mapAsWord64 :: ByteArray bs => (Word64 -> Word64) -> bs -> bs
mapAsWord64 f bs =
    allocAndFreeze len $ \dst ->
    withByteArray bs $ \src ->
        loop len dst src
  where
    len = length bs

    -- The first argument is the number of bytes not yet written.
    loop remaining d s
        | remaining >= 8 = do
            w <- peek s
            let r = f (fromBE64 w)
            poke d (toBE64 r)
            loop (remaining - 8) (d `plusPtr` 8) (s `plusPtr` 8)
        | remaining <= 0 = return ()
        -- Partial final word: too short to map, so copy it verbatim.
        | otherwise = bufCopy (castPtr d) (castPtr s) remaining
-- | Convert between byte array types by copying the input's bytes into a
-- newly allocated array of the output type.
convert :: (ByteArrayAccess bin, ByteArray bout) => bin -> bout
convert input = copyAndFreeze input (\_ -> return ())