oauth2-mock-server/src/AuthCode.hs
2024-01-21 20:35:40 +01:00

182 lines
5.8 KiB
Haskell

-- SPDX-FileCopyrightText: 2024 UniWorX Systems
-- SPDX-FileContributor: David Mosbach <david.mosbach@uniworx.de>
--
-- SPDX-License-Identifier: AGPL-3.0-or-later
{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables, TypeApplications, LambdaCase #-}
module AuthCode
( State(..)
, AuthState
, AuthRequest(..)
, JWT(..)
, JWTWrapper(..)
, genUnencryptedCode
, verify
, mkToken
, decodeToken
, renewToken
) where
import User
import Data.Aeson
import Data.ByteString (ByteString (..), fromStrict, toStrict)
import Data.Either (fromRight)
import Data.Map.Strict (Map)
import Data.Maybe (isJust, fromMaybe, fromJust)
import Data.Time.Clock
import Data.Text (pack, replace, Text, stripPrefix)
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Text.Encoding.Base64
import Data.UUID
import Data.UUID.V4
import qualified Data.ByteString.Char8 as BS
import qualified Data.Map.Strict as M
import Control.Concurrent (forkIO, threadDelay)
import Control.Concurrent.STM.TVar
import Control.Monad (void, (>=>))
import Control.Monad.STM
import Jose.Jwa
import Jose.Jwe
import Jose.Jwk (Jwk(..))
import Jose.Jwt hiding (decode, encode)
import Servant.API (FromHttpApiData(..))
-- | Claims carried inside both the access and the refresh token issued by
-- 'mkToken'. Serialised with the JSON keys @iss@, @exp@ and @jti@.
data JWT = JWT
  { issuer :: Text        -- ^ Issuer (@iss@); 'mkToken' sets it to @"Oauth2MockServer"@.
  , expiration :: UTCTime -- ^ Absolute expiry instant (@exp@), encoded with aeson's 'UTCTime' format.
  , jti :: UUID           -- ^ Token identifier (@jti@); the key of the grant in 'activeTokens'.
  } deriving (Show, Eq)
-- | Encode the claims under their registered JWT claim names.
instance ToJSON JWT where
  toJSON claims = object
    [ "iss" .= issuer claims
    , "exp" .= expiration claims
    , "jti" .= jti claims
    ]
-- | Decode the claims produced by the 'ToJSON' instance.
--
-- Non-object input is reported as an ordinary aeson parse failure rather than
-- crashing with a non-exhaustive pattern match.
instance FromJSON JWT where
  parseJSON = withObject "JWT" $ \o ->
    JWT <$> o .: "iss" <*> o .: "exp" <*> o .: "jti"
-- | Token response handed to clients and accepted back by 'renewToken'.
data JWTWrapper = JWTW
  { acessToken :: String         -- ^ Compact-serialised access token (a JWE when produced by 'mkToken').
                                 --   NOTE: the field name is misspelt; renaming it would break callers.
  , expiresIn :: NominalDiffTime -- ^ Lifetime of the access token.
  , refreshToken :: Maybe String -- ^ Compact-serialised refresh token, if one was issued.
  } deriving (Show)
-- | Serialise as an OAuth 2.0 token response.
--
-- @expires_in@ is written as a whole number of seconds, which is also the unit
-- the 'FromJSON' instance reads back.
instance ToJSON JWTWrapper where
  toJSON (JWTW a e r) = object
    [ "access_token" .= a
    , "token_type" .= ("JWT" :: Text)
      -- 'fromEnum' on 'NominalDiffTime' counts picoseconds, so round to seconds explicitly.
    , "expires_in" .= (round e :: Integer)
    , "refresh_token" .= r
    ]
-- | Parse an OAuth 2.0 token response. @expires_in@ is read as seconds and
-- @refresh_token@ is optional. Non-object input fails the parse instead of
-- crashing with a non-exhaustive pattern match.
instance FromJSON JWTWrapper where
  parseJSON = withObject "JWTWrapper" $ \o -> JWTW
    <$> o .: "access_token"
    <*> o .: "expires_in"
    <*> o .:? "refresh_token"
-- | Accept a JSON-encoded 'JWTWrapper' from a header, URL piece or query
-- parameter.
instance FromHttpApiData JWTWrapper where
  -- The class's MINIMAL set is @parseUrlPiece | parseQueryParam@. Their
  -- defaults call each other, so without this definition URL and query parsing
  -- would never terminate.
  parseUrlPiece = parseHeader . encodeUtf8
  parseHeader bs = case decode (fromStrict bs) of
    Just x -> Right x
    Nothing -> Left "Invalid JWT wrapper"
-- | A pending authorisation request. It is stored under its code in
-- 'activeCodes' until it is redeemed by 'verify' or removed after
-- 'codeExpiration' by 'genUnencryptedCode'.
data AuthRequest user = AuthRequest
  { client :: String                  -- ^ Client ID the code was issued to; compared by 'verify'.
  , codeExpiration :: NominalDiffTime -- ^ How long the generated code stays registered.
  , user :: user                      -- ^ User the code was issued for.
  , scopes :: [Scope user]            -- ^ Scopes attached to the request.
  }
-- | Mutable server state, shared between handlers through 'AuthState'.
data State user = State
  { activeCodes :: Map Text (AuthRequest user)    -- ^ Outstanding authorisation codes, keyed by code.
  , activeTokens :: Map UUID (user, [Scope user]) -- ^ Issued token pairs, keyed by their shared @jti@.
  , publicKey :: Jwk                              -- ^ Key 'mkToken' encrypts tokens with.
  , privateKey :: Jwk                             -- ^ Key 'decodeToken' decrypts tokens with.
  }
-- | Server state held in an STM 'TVar' so handlers can update it atomically.
type AuthState user = TVar (State user)
-- | Generate an authorisation code for @req@, register it in 'activeCodes'
-- and schedule its removal after 'codeExpiration'.
--
-- The code is the Base64 encoding of the client ID, @url@, the current time
-- and the expiry time, with spaces stripped. @/@ and @=@ are percent-encoded
-- so the code can be embedded in a URL. Returns 'Nothing' if an identical code
-- is already registered.
--
-- NOTE(review): the code is derived only from predictable inputs and is
-- therefore guessable; acceptable for a mock server, not for production use.
genUnencryptedCode :: AuthRequest user
                   -> String
                   -> AuthState user
                   -> IO (Maybe Text)
genUnencryptedCode req url state = do
  now <- getCurrentTime
  let
    expiresAt = req.codeExpiration `addUTCTime` now
    simpleCode = replace "/" "%2F" . replace "=" "%3D" . encodeBase64 . pack . filter (/= ' ') $ req.client <> url <> show now <> show expiresAt
  success <- atomically . stateTVar state $ \s ->
    if M.member simpleCode s.activeCodes
      then (False, s)
      else (True, s { activeCodes = M.insert simpleCode req s.activeCodes })
  if success
    then expire simpleCode req.codeExpiration state >> return (Just simpleCode)
    else return Nothing
  where
    -- Remove @code@ from 'activeCodes' in a background thread once @time@ has elapsed.
    expire :: Text -> NominalDiffTime -> AuthState user -> IO ()
    expire code time st = void . forkIO $ do
      -- threadDelay expects microseconds; 'fromEnum' on NominalDiffTime would
      -- yield picoseconds and postpone expiry by a factor of 10^6.
      threadDelay . round $ time * 1000000
      atomically . modifyTVar' st $ \s -> s { activeCodes = M.delete code s.activeCodes }
-- | Redeem an authorisation code.
--
-- The code is looked up and removed from 'activeCodes' in one transaction, so
-- it can be redeemed at most once, even when the client check below fails.
-- If a client ID is supplied, it must equal the one stored with the code;
-- with 'Nothing' the check is skipped. Returns the request's user and scopes
-- on success, otherwise 'Nothing'.
verify :: Text -> Maybe String -> AuthState user -> IO (Maybe (user, [Scope user]))
verify code mClientID state = do
  mData <- atomically . stateTVar state $ \s ->
    (M.lookup code s.activeCodes, s { activeCodes = M.delete code s.activeCodes })
  return $ case mData of
    Just (AuthRequest clientID' _ u grantedScopes)
      | maybe True (== clientID') mClientID -> Just (u, grantedScopes)
    _ -> Nothing
-- | Issue a fresh access/refresh token pair for @u@ with @scopes@.
--
-- Both tokens carry the same @jti@. That @jti@ is registered in 'activeTokens'
-- so 'renewToken' can later find the grant. Each token's JSON claims are
-- wrapped as a nested JWT and encrypted with the server's public key
-- (RSA-OAEP-256 / A128GCM). The access token lives 120 seconds and the
-- refresh token one day.
--
-- Calls 'error' if either encryption fails.
mkToken :: user -> [Scope user] -> AuthState user -> IO JWTWrapper
mkToken u scopes state = do
  pubKey <- publicKey <$> readTVarIO state
  now <- getCurrentTime
  uuid <- nextRandom
  let
    lifetimeAT = 120 :: NominalDiffTime -- TODO make configurable
    lifetimeRT = nominalDay -- TODO make configurable
    at = JWT "Oauth2MockServer" (lifetimeAT `addUTCTime` now) uuid
    rt = JWT "Oauth2MockServer" (lifetimeRT `addUTCTime` now) uuid
    encrypt claims = jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode claims)
  encodedAT <- encrypt at
  encodedRT <- encrypt rt
  -- Pair both results; the first failure (access token first) is reported.
  case (,) <$> encodedAT <*> encodedRT of
    Right (Jwt aToken, Jwt rToken) -> do
      atomically . modifyTVar' state $ \s -> s { activeTokens = M.insert uuid (u, scopes) (activeTokens s) }
      return $ JWTW (BS.unpack aToken) lifetimeAT (Just $ BS.unpack rToken)
    Left e -> error $ show e
-- | Decode a compact-serialised token (UTF-8 encoded from @token@) with the
-- server's private key.
decodeToken :: Text -> AuthState user -> IO (Either JwtError JwtContent)
decodeToken token state =
  readTVarIO state >>= \current -> jwkDecode (privateKey current) (encodeUtf8 token)
-- | Exchange a refresh token for a new token pair.
--
-- The refresh token comes from the wrapper's 'refreshToken' field and must
-- start with @"Bearer "@. It is decoded with the server key. A new pair is
-- issued via 'mkToken' only when all of the following hold:
--
--   * the payload is a JWE whose claims parse as 'JWT';
--   * the token has not reached its expiry time;
--   * its @jti@ is still registered in 'activeTokens'.
--
-- In that case the old grant is removed first. Every other case, including
-- malformed or non-JWE tokens, yields 'Nothing'.
renewToken :: JWTWrapper -> AuthState user -> IO (Maybe JWTWrapper)
renewToken (JWTW _ _ rt) state = case rt >>= stripPrefix "Bearer " . pack of
  Nothing -> return Nothing
  Just t -> decodeToken t state >>= \case
    Right (Jwe (_, body)) | Just jwt <- decode @JWT (fromStrict body) -> do
      now <- getCurrentTime
      -- Only a refresh token that has not yet expired may be redeemed.
      if now >= expiration jwt then return Nothing else do
        mUser <- atomically . stateTVar state $ \s ->
          (M.lookup (jti jwt) s.activeTokens, s { activeTokens = M.delete (jti jwt) s.activeTokens })
        case mUser of
          Just (u, grantedScopes) -> Just <$> mkToken u grantedScopes state
          Nothing -> return Nothing
    -- Decoding errors, non-JWE content and unparsable claims are all rejected.
    _ -> return Nothing