-- oauth2-mock-server/src/AuthCode.hs
-- 2024-03-03 21:02:57 +00:00

-- 254 lines
-- 8.0 KiB
-- Haskell

-- SPDX-FileCopyrightText: 2024 UniWorX Systems
-- SPDX-FileContributor: David Mosbach <david.mosbach@uniworx.de>
--
-- SPDX-License-Identifier: AGPL-3.0-or-later
{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables, TypeApplications, LambdaCase, DeriveGeneric, AllowAmbiguousTypes #-}
module AuthCode
( State(..)
, AuthState
, AuthRequest(..)
, TokenParams(..)
, JWT(..)
, JWTWrapper(..)
, genUnencryptedCode
, verify
, mkToken
, decodeToken
, renewToken
) where
import Prelude hiding (exp)
import User
import Data.Aeson
import Data.Bool (bool)
import Data.ByteString (ByteString (..), fromStrict, toStrict)
import Data.Either (fromRight)
import Data.List ((\\))
import Data.Map.Strict (Map)
import Data.Maybe (isJust, fromMaybe, fromJust, catMaybes)
import Data.Time.Calendar
import Data.Time.Clock
import Data.Text (pack, replace, Text, stripPrefix)
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Text.Encoding.Base64
import Data.UUID hiding (null)
import Data.UUID.V4
import qualified Data.ByteString.Char8 as BS
import qualified Data.Map.Strict as M
import Control.Concurrent (forkIO, threadDelay)
import Control.Concurrent.STM.TVar
import Control.Monad (void, (>=>))
import Control.Monad.STM
import GHC.Generics
import Jose.Jwa
import Jose.Jwe
import Jose.Jwk (Jwk(..))
import Jose.Jwt hiding (decode, encode)
import qualified Jose.Jws as Jws
import Servant.API (FromHttpApiData(..))
import System.Environment (getEnv)
--------------
---- Tokens ----
--------------
-- | Claims carried inside the encrypted access and refresh tokens.
--
-- @exp@ is serialised via 'UTCTime''s own JSON instance, not as a numeric
-- date; the tokens are only ever decoded again by this server.
data JWT = JWT
  { issuer :: Text
  , expiration :: UTCTime
  , jti :: UUID -- ^ token ID; the key under which the token is kept in 'activeTokens'
  } deriving (Show, Eq)
instance ToJSON JWT where
  toJSON (JWT i e j) = object ["iss" .= i, "exp" .= e, "jti" .= j]
-- | Total parser: non-object input is reported as a parse failure
-- instead of crashing with a pattern-match error.
instance FromJSON JWT where
  parseJSON = withObject "JWT" $ \o ->
    JWT <$> o .: "iss" <*> o .: "exp" <*> o .: "jti"
-- | Claims of the OpenID Connect ID token built by 'mkToken'.
--
-- The JSON instances are derived generically, so the record field names
-- below are the JSON claim names; do not rename them. ('exp' clashes with
-- the Prelude, which is why the module hides @Prelude.exp@.)
data IDToken = IDT
  { iss :: Text -- ^ issuer
  , sub :: Text -- ^ subject: the shown user ID
  , aud :: [Text] -- ^ audience: the requesting client ID, if one was given
  , exp :: NominalDiffTime -- ^ expiry time, seconds since 1970-01-01 UTC
  , iat :: NominalDiffTime -- ^ issue time, seconds since 1970-01-01 UTC
  , auth_time :: Maybe NominalDiffTime -- ^ authentication time, seconds since 1970-01-01 UTC ('mkToken' uses the issue time)
  , nonce :: Maybe Text -- ^ nonce from the authorization request, echoed back
  } deriving (Generic, Show)
instance ToJSON IDToken
instance FromJSON IDToken
-- | Token endpoint response: the token set handed to the client.
data JWTWrapper = JWTW
  { acessToken :: String -- ^ encrypted access token (field name typo kept: it is exported)
  , expiresIn :: NominalDiffTime -- ^ access token lifetime
  , refreshToken :: Maybe String
  , idToken :: Maybe String
  } deriving (Show)

-- | Serialises in the OAuth 2.0 token response shape. @expires_in@ is
-- emitted as whole seconds (RFC 6749, section 5.1).
-- NOTE(review): RFC 6749 clients usually expect @token_type@ "Bearer";
-- "JWT" is kept as-is here — confirm clients accept it.
instance ToJSON JWTWrapper where
  toJSON (JWTW a e r i) = object
    [ "access_token" .= a
    , "token_type" .= ("JWT" :: Text)
    -- 'fromEnum' on a NominalDiffTime counts picoseconds, so convert
    -- to seconds explicitly.
    , "expires_in" .= (floor e :: Integer)
    , "refresh_token" .= r
    , "id_token" .= i ]

-- | Inverse of the 'ToJSON' instance. Non-object input is a parse failure.
instance FromJSON JWTWrapper where
  parseJSON = withObject "JWTWrapper" $ \o -> JWTW
    <$> o .: "access_token"
    <*> o .: "expires_in"
    <*> o .:? "refresh_token"
    <*> o .:? "id_token"

-- | Parses a JSON-encoded 'JWTWrapper'. 'parseUrlPiece' is defined too:
-- the class defaults of 'parseUrlPiece' and 'parseQueryParam' refer to each
-- other and would loop forever if neither were given.
instance FromHttpApiData JWTWrapper where
  parseHeader = maybe (Left "Invalid JWT wrapper") Right . decode . fromStrict
  parseUrlPiece = parseHeader . encodeUtf8
-------------
---- State ----
-------------
-- | A pending authorization request, stored in 'activeCodes' under its
-- authorization code until 'verify' redeems it or it expires.
data AuthRequest user = AuthRequest
  { client :: String -- ^ client ID the code was issued to
  , codeExpiration :: NominalDiffTime -- ^ lifetime of the code, in seconds
  , user :: user
  , scopes :: [Scope' user]
  , rNonce :: Maybe Text -- ^ nonce from the request, if any; ends up in the ID token
  }
-- | What is remembered about an issued token: the user, the granted scopes
-- and the nonce.
type TokenParams user = (user, [Scope' user], Maybe Text)
-- | Server state.
data State user = State
  { activeCodes :: Map Text (AuthRequest user) -- ^ unredeemed authorization codes
  , activeTokens :: Map UUID (TokenParams user) -- ^ issued tokens, keyed by their @jti@
  , publicKey :: Jwk -- ^ encrypts access and refresh tokens
  , privateKey :: Jwk -- ^ decrypts tokens and signs ID tokens
  }
-- | Mutable handle to the server state; all updates go through STM.
type AuthState user = TVar (State user)
-----------------
---- Functions ----
-----------------
-- | Create an authorization code for @req@ and register it as active.
--
-- The code is derived (not random) from the client ID, @url@, the current
-- time and the expiry time. That string has its spaces removed, is
-- Base64-encoded, and then has @/@ and @=@ percent-escaped.
-- NOTE(review): @+@ from the Base64 alphabet is not escaped — confirm the
-- code survives query-string decoding on the redirect.
--
-- Returns 'Nothing' if an identical code is already active. Otherwise the
-- code is stored, and a background thread removes it once
-- 'codeExpiration' has elapsed.
genUnencryptedCode :: AuthRequest user
                   -> String
                   -> AuthState user
                   -> IO (Maybe Text)
genUnencryptedCode req url state = do
  now <- getCurrentTime
  let
    expiresAt = req.codeExpiration `addUTCTime` now
    simpleCode = replace "/" "%2F" . replace "=" "%3D" . encodeBase64 . pack . filter (/= ' ') $ req.client <> url <> show now <> show expiresAt
  success <- atomically . stateTVar state $ \s ->
    if M.member simpleCode s.activeCodes
      then (False, s)
      else (True, s{ activeCodes = M.insert simpleCode req s.activeCodes })
  if success
    then Just simpleCode <$ expire simpleCode req.codeExpiration state
    else return Nothing
  where
    -- Fork a thread that drops @code@ from the active set after @time@.
    expire :: Text -> NominalDiffTime -> AuthState user -> IO ()
    expire code time st = void . forkIO $ do
      -- threadDelay expects microseconds; 'fromEnum' on a NominalDiffTime
      -- would yield picoseconds (a 10^6 times longer delay).
      threadDelay . floor $ time * 1000000
      atomically . modifyTVar' st $ \s -> s{ activeCodes = M.delete code s.activeCodes }
-- | Redeem an authorization code.
--
-- The code is looked up and removed in a single transaction, so it can be
-- redeemed at most once. It is consumed even if the client check below
-- fails. If a client ID is supplied, it must equal the one the code was
-- issued to; if none is supplied, the check is skipped.
--
-- Returns the user, the granted scopes and the nonce on success.
verify :: Text
       -> Maybe String
       -> AuthState user
       -> IO (Maybe (TokenParams user))
verify code mClientID state = do
  mRequest <- atomically . stateTVar state $ \s ->
    (M.lookup code s.activeCodes, s{ activeCodes = M.delete code s.activeCodes })
  return $ case mRequest of
    Just (AuthRequest issuedTo _ u granted n)
      | maybe True (== issuedTo) mClientID -> Just (u, granted, n)
    _ -> Nothing
-- | Issue a new token set for the given token parameters. The parameters
-- are then registered under the new token ID (@jti@) so that 'renewToken'
-- can find them again.
--
-- * The access and refresh tokens are JWEs (RSA-OAEP-256 / A128GCM),
--   encrypted with the state's public key. Both carry the same @jti@.
-- * The ID token is a JWS (RS256) signed with the state's private key. It
--   is returned only when the OpenID scope was granted.
--
-- Reads the @OAUTH2_SERVER_PORT@ environment variable for the ID token's
-- issuer ('getEnv' throws if it is unset). Calls 'error' if any of the
-- encodings fails.
mkToken :: forall user userData . UserData user userData
        => TokenParams user
        -> Maybe Text -- client_id
        -> AuthState user
        -> IO JWTWrapper
mkToken (u, scopes, nonce) clientID state = do
  st <- readTVarIO state
  let pubKey = publicKey st
      privKey = privateKey st
  now <- getCurrentTime
  uuid <- nextRandom
  port <- pack <$> getEnv "OAUTH2_SERVER_PORT"
  let
    lifetimeAT = 3600 :: NominalDiffTime -- TODO make configurable
    lifetimeRT = nominalDay -- TODO make configurable
    lifetimeIT = 3600 :: NominalDiffTime -- TODO make configurable
    -- The ID token's time claims are seconds since the Unix epoch.
    itRefDate = UTCTime (fromGregorian 1970 1 1) 0
    at = JWT "Oauth2MockServer" (lifetimeAT `addUTCTime` now) uuid
    rt = JWT "Oauth2MockServer" (lifetimeRT `addUTCTime` now) uuid
    it = IDT
      { iss = "http://localhost:" <> port -- TODO maybe make configurable
      , sub = pack . show $ userID @user @userData u
      , aud = catMaybes [clientID]
      , exp = (lifetimeIT `addUTCTime` now) `diffUTCTime` itRefDate
      , iat = now `diffUTCTime` itRefDate
      , auth_time = Just $ now `diffUTCTime` itRefDate
      , nonce = nonce
      }
  encodedAT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode at)
  encodedRT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode rt)
  -- NOTE(review): 'Nested' marks the payload as a nested JWT; for the ID
  -- token's plain claims, 'Claims' may be the intended variant — confirm.
  encodedIT <- Jws.jwkEncode RS256 privKey (Nested . Jwt . toStrict $ encode it)
  -- Reports the first failure, in the same order as the encodings above.
  case (,,) <$> encodedAT <*> encodedRT <*> encodedIT of
    Left e -> error $ show e
    Right (Jwt aToken, Jwt rToken, Jwt iToken) -> do
      atomically . modifyTVar' state $ \s -> s { activeTokens = M.insert uuid (u, scopes, nonce) (activeTokens s) }
      return JWTW
        { acessToken = BS.unpack aToken
        , expiresIn = lifetimeAT
        , refreshToken = Just $ BS.unpack rToken
          -- OpenID Connect: the ID token belongs to requests with the openid scope.
        , idToken = if Left OpenID `elem` scopes then Just (BS.unpack iToken) else Nothing
        }
-- | Decode a compact-serialised (JWE) token with the server's private key.
-- The jose library's result is returned as-is.
decodeToken :: Text -> AuthState user -> IO (Either JwtError JwtContent)
decodeToken token state = do
  current <- readTVarIO state
  let key = privateKey current
  jwkDecode key (encodeUtf8 token)
-- | Exchange a refresh token for a new token set.
--
-- Returns 'Nothing' in any of these cases:
--
-- * the token cannot be decoded, or is not a JWE;
-- * its payload is not valid 'JWT' claims;
-- * it has expired;
-- * it is not (or no longer) registered;
-- * @scopes@ asks for anything that was not originally granted.
--
-- A registered token is consumed by the attempt even when the scope check
-- then fails.
renewToken :: forall user userData . UserData user userData
           => Text -- ^ token
           -> [Scope' user]
           -> Maybe Text -- ^ client_id
           -> AuthState user
           -> IO (Maybe JWTWrapper) -- TODO more descriptive failures
renewToken t scopes clientID state = decodeToken t state >>= \case
  Right (Jwe (_, body)) -> case decode @JWT (fromStrict body) of
    Nothing -> return Nothing
    Just jwt -> do
      now <- getCurrentTime
      if now >= expiration jwt
        then return Nothing
        else do
          mParams <- atomically . stateTVar state $ \s ->
            (M.lookup (jti jwt) s.activeTokens, s { activeTokens = M.delete (jti jwt) s.activeTokens })
          case mParams of
            Just (u, scopes', nonce)
              | null (scopes \\ scopes') -> Just <$> mkToken @user @userData (u, scopes, nonce) clientID state
            _ -> return Nothing
  -- Decode errors and any non-JWE content ('JwtContent' has other
  -- constructors) are rejected instead of crashing on a missing pattern.
  _ -> return Nothing