diff --git a/oauth2-mock-server.cabal b/oauth2-mock-server.cabal index eeb6d46..b7b9dc2 100644 --- a/oauth2-mock-server.cabal +++ b/oauth2-mock-server.cabal @@ -32,6 +32,7 @@ library build-depends: aeson , base >=4.7 && <5 + , base64 , bytestring , containers , http-client @@ -58,6 +59,7 @@ executable oauth2-mock-server-exe build-depends: aeson , base >=4.7 && <5 + , base64 , bytestring , containers , http-client @@ -86,6 +88,7 @@ test-suite oauth2-mock-server-test build-depends: aeson , base >=4.7 && <5 + , base64 , bytestring , containers , http-client diff --git a/package.yaml b/package.yaml index 320c558..040fdf5 100644 --- a/package.yaml +++ b/package.yaml @@ -34,6 +34,7 @@ dependencies: - transformers - bytestring - jose-jwt +- base64 ghc-options: - -Wall diff --git a/src/AuthCode.hs b/src/AuthCode.hs index da1dfe1..ed39eaf 100644 --- a/src/AuthCode.hs +++ b/src/AuthCode.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE OverloadedRecordDot #-} +{-# LANGUAGE OverloadedRecordDot, OverloadedStrings #-} module AuthCode ( State (..) @@ -10,6 +10,8 @@ module AuthCode import Data.Map.Strict (Map) import Data.Maybe (isJust) import Data.Time.Clock +import Data.Text (pack, replace, Text) +import Data.Text.Encoding.Base64 import qualified Data.Map.Strict as M @@ -24,7 +26,7 @@ import Jose.Jwk (Jwk(..)) data State = State - { activeCodes :: Map String (String, UTCTime) -- ^ maps auth codes to (client ID, expiration time) + { activeCodes :: Map Text (String, UTCTime) -- ^ maps auth codes to (client ID, expiration time) , publicKey :: Jwk , privateKey :: Jwk } deriving Show @@ -35,12 +37,12 @@ genUnencryptedCode :: String -> String -> NominalDiffTime -> AuthState - -> IO (Maybe String) + -> IO (Maybe Text) genUnencryptedCode client url expiration state = do now <- getCurrentTime let expiresAt = expiration `addUTCTime` now - simpleCode = filter (/= ' ') $ client <> url <> show now <> show expiresAt + simpleCode = replace "=" "" . encodeBase64 . pack . filter (/= ' ') $ client <> url <> show now <> show expiresAt success <- atomically . stateTVar state $ \s -> let mEntry = M.lookup simpleCode s.activeCodes in @@ -50,13 +52,13 @@ genUnencryptedCode client url expiration state = do if success then expire simpleCode expiration state >> return (Just simpleCode) else return Nothing -expire :: String -> NominalDiffTime -> AuthState -> IO () +expire :: Text -> NominalDiffTime -> AuthState -> IO () expire code time state = void . forkIO $ do threadDelay $ fromEnum time atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes } -verify :: String -> String -> AuthState -> IO Bool +verify :: Text -> String -> AuthState -> IO Bool verify code clientID state = do now <- getCurrentTime mData <- atomically $ do diff --git a/src/Server.hs b/src/Server.hs index 52f224d..c8ab1de 100644 --- a/src/Server.hs +++ b/src/Server.hs @@ -130,14 +130,14 @@ authServer = handleAuth redirect :: Maybe ByteString -> AuthHandler userData redirect (Just url) = liftIO (print url) >> throwError err303 { errHeaders = [("Location", url)]} redirect Nothing = throwError err500 { errBody = "Could not generate authorisation code."} - addParams :: String -> Maybe String -> Maybe String -> Maybe ByteString + addParams :: String -> Maybe Text -> Maybe String -> Maybe ByteString addParams url Nothing _ = Nothing addParams url (Just code) mState = let qPos = fromMaybe (length url) $ elemIndex '?' url (pre, post) = splitAt qPos url rState = case mState of {Just s -> "&state=" ++ s; Nothing -> ""} post' = if not (null post) then '&' : tail post else post - in Just . fromString $ pre ++ "?code=" ++ code ++ post' ++ rState + in Just . fromString $ pre ++ "?code=" ++ (unpack code) ++ post' ++ rState ---------------------- @@ -217,7 +217,7 @@ tokenEndpoint = provideToken AuthCode -> do unless (Client (pack $ clientID client) (pack $ clientSecret client) `elem` trustedClients) . throwError $ err500 { errBody = "Invalid client" } - valid <- asks (verify (grant client) (clientID client)) >>= liftIO + valid <- asks (verify (pack $ grant client) (clientID client)) >>= liftIO unless valid . throwError $ err500 { errBody = "Invalid authorisation code" } -- return JWT {token = "", tokenType = "JWT", expiration = 0.25 * nominalDay} token <- asks mkToken >>= liftIO