From 2e0a60f7f737eab2ab8113bdb011bf37aac2deab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Ch=C3=A9ron?= Date: Sun, 23 Feb 2020 09:02:10 +0100 Subject: [PATCH] Use Semigroup API --- Crypto/Internal/Builder.hs | 10 +++++----- Crypto/Internal/Imports.hs | 4 ++++ Crypto/MAC/KMAC.hs | 15 ++++++++------- Crypto/PubKey/EdDSA.hs | 12 ++++++------ 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/Crypto/Internal/Builder.hs b/Crypto/Internal/Builder.hs index d33ebfd..fd5b920 100644 --- a/Crypto/Internal/Builder.hs +++ b/Crypto/Internal/Builder.hs @@ -14,7 +14,6 @@ module Crypto.Internal.Builder ( Builder , buildAndFreeze , builderLength - , (<+>) , byte , bytes , zero @@ -23,16 +22,17 @@ module Crypto.Internal.Builder import Data.ByteArray (ByteArray, ByteArrayAccess) import qualified Data.ByteArray as B import Data.Memory.PtrMethods (memSet) -import Data.Word (Word8) import Foreign.Ptr (Ptr, plusPtr) import Foreign.Storable (poke) +import Crypto.Internal.Imports + data Builder = Builder !Int (Ptr Word8 -> IO ()) -- size and initializer -(<+>) :: Builder -> Builder -> Builder -(Builder s1 f1) <+> (Builder s2 f2) = Builder (s1 + s2) f - where f p = f1 p >> f2 (p `plusPtr` s1) +instance Semigroup Builder where + (Builder s1 f1) <> (Builder s2 f2) = Builder (s1 + s2) f + where f p = f1 p >> f2 (p `plusPtr` s1) builderLength :: Builder -> Int builderLength (Builder s _) = s diff --git a/Crypto/Internal/Imports.hs b/Crypto/Internal/Imports.hs index 4ed44e1..6d551e9 100644 --- a/Crypto/Internal/Imports.hs +++ b/Crypto/Internal/Imports.hs @@ -5,11 +5,15 @@ -- Stability : experimental -- Portability : unknown -- +{-# LANGUAGE CPP #-} module Crypto.Internal.Imports ( module X ) where import Data.Word as X +#if !(MIN_VERSION_base(4,11,0)) +import Data.Semigroup as X (Semigroup(..)) +#endif import Control.Applicative as X import Control.Monad as X (forM, forM_, void) import Control.Arrow as X (first, second) diff --git a/Crypto/MAC/KMAC.hs b/Crypto/MAC/KMAC.hs index def8b98..b7ad88e 100644 --- a/Crypto/MAC/KMAC.hs +++ b/Crypto/MAC/KMAC.hs @@ -28,6 +28,7 @@ import Crypto.Hash.SHAKE (HashSHAKE(..)) import Crypto.Hash.Types (HashAlgorithm(..), Digest(..)) import qualified Crypto.Hash.Types as H import Crypto.Internal.Builder +import Crypto.Internal.Imports import Foreign.Ptr (Ptr) import Data.Bits (shiftR) import Data.ByteArray (ByteArrayAccess) @@ -45,7 +46,7 @@ cshakeInit n s p = H.Context $ B.allocAndFreeze c $ \(ptr :: Ptr (H.Context a)) where c = hashInternalContextSize (undefined :: a) w = hashBlockSize (undefined :: a) - x = encodeString n <+> encodeString s + x = encodeString n <> encodeString s b = buildAndFreeze (bytepad x w) :: B.Bytes cshakeUpdate :: (HashSHAKE a, ByteArrayAccess ba) @@ -75,7 +76,7 @@ cshakeFinalize !c s = -- The Eq instance is constant time. No Show instance is provided, to avoid -- printing by mistake. newtype KMAC a = KMAC { kmacGetDigest :: Digest a } - deriving ByteArrayAccess + deriving (ByteArrayAccess,NFData) instance Eq (KMAC a) where (KMAC b1) == (KMAC b2) = B.constEq b1 b2 @@ -118,26 +119,26 @@ finalize (Context ctx) = KMAC $ cshakeFinalize ctx suffix -- Utilities bytepad :: Builder -> Int -> Builder -bytepad x w = prefix <+> x <+> zero padLen +bytepad x w = prefix <> x <> zero padLen where prefix = leftEncode w padLen = (w - builderLength prefix - builderLength x) `mod` w encodeString :: ByteArrayAccess bin => bin -> Builder -encodeString s = leftEncode (8 * B.length s) <+> bytes s +encodeString s = leftEncode (8 * B.length s) <> bytes s leftEncode :: Int -> Builder -leftEncode x = byte len <+> digits +leftEncode x = byte len <> digits where digits = i2osp x len = fromIntegral (builderLength digits) rightEncode :: Int -> Builder -rightEncode x = digits <+> byte len +rightEncode x = digits <> byte len where digits = i2osp x len = fromIntegral (builderLength digits) i2osp :: Int -> Builder -i2osp i | i >= 256 = i2osp (shiftR i 8) <+> byte (fromIntegral i) +i2osp i | i >= 256 = i2osp (shiftR i 8) <> byte (fromIntegral i) | otherwise = byte (fromIntegral i) diff --git a/Crypto/PubKey/EdDSA.hs b/Crypto/PubKey/EdDSA.hs index 67b733c..95fa7fd 100644 --- a/Crypto/PubKey/EdDSA.hs +++ b/Crypto/PubKey/EdDSA.hs @@ -296,7 +296,7 @@ getK :: forall proxy curve hash ctx msg . => proxy curve -> Bool -> ctx -> PublicKey curve hash -> Bytes -> msg -> Scalar curve getK prx ph ctx (PublicKey pub) bsR msg = let alg = undefined :: hash - digK = hashWithDom prx alg ph ctx (bytes bsR <+> bytes pub) msg + digK = hashWithDom prx alg ph ctx (bytes bsR <> bytes pub) msg in decodeScalarNoErr prx digK encodeSignature :: EllipticCurveEdDSA curve @@ -304,7 +304,7 @@ encodeSignature :: EllipticCurveEdDSA curve -> (Bytes, Point curve, Scalar curve) -> Signature curve hash encodeSignature prx (bsR, _, sS) = Signature $ buildAndFreeze $ - bytes bsR <+> bytes bsS <+> zero len0 + bytes bsR <> bytes bsS <> zero len0 where bsS = encodeScalarLE prx sS :: Bytes len0 = signatureSize prx - B.length bsR - B.length bsS @@ -339,10 +339,10 @@ instance EllipticCurveEdDSA Curve_Edwards25519 where hashWithDom _ alg ph ctx bss | not ph && B.null ctx = digestDomMsg alg bss - | otherwise = digestDomMsg alg (dom <+> bss) - where dom = bytes ("SigEd25519 no Ed25519 collisions" :: ByteString) <+> - byte (if ph then 1 else 0) <+> - byte (fromIntegral $ B.length ctx) <+> + | otherwise = digestDomMsg alg (dom <> bss) + where dom = bytes ("SigEd25519 no Ed25519 collisions" :: ByteString) <> + byte (if ph then 1 else 0) <> + byte (fromIntegral $ B.length ctx) <> bytes ctx pointPublic _ = PublicKey . Edwards25519.pointEncode