182 lines
5.8 KiB
Haskell
182 lines
5.8 KiB
Haskell
-- SPDX-FileCopyrightText: 2024 UniWorX Systems
|
|
-- SPDX-FileContributor: David Mosbach <david.mosbach@uniworx.de>
|
|
--
|
|
-- SPDX-License-Identifier: AGPL-3.0-or-later
|
|
|
|
{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables, TypeApplications, LambdaCase #-}
|
|
|
|
module AuthCode
|
|
( State(..)
|
|
, AuthState
|
|
, AuthRequest(..)
|
|
, JWT(..)
|
|
, JWTWrapper(..)
|
|
, genUnencryptedCode
|
|
, verify
|
|
, mkToken
|
|
, decodeToken
|
|
, renewToken
|
|
) where
|
|
|
|
import User
|
|
|
|
import Data.Aeson
|
|
import Data.ByteString (ByteString (..), fromStrict, toStrict)
|
|
import Data.Either (fromRight)
|
|
import Data.Map.Strict (Map)
|
|
import Data.Maybe (isJust, fromMaybe, fromJust)
|
|
import Data.Time.Clock
|
|
import Data.Text (pack, replace, Text, stripPrefix)
|
|
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
|
|
import Data.Text.Encoding.Base64
|
|
import Data.UUID
|
|
import Data.UUID.V4
|
|
|
|
import qualified Data.ByteString.Char8 as BS
|
|
import qualified Data.Map.Strict as M
|
|
|
|
import Control.Concurrent (forkIO, threadDelay)
|
|
import Control.Concurrent.STM.TVar
|
|
import Control.Monad (void, (>=>))
|
|
import Control.Monad.STM
|
|
|
|
import Jose.Jwa
|
|
import Jose.Jwe
|
|
import Jose.Jwk (Jwk(..))
|
|
import Jose.Jwt hiding (decode, encode)
|
|
|
|
import Servant.API (FromHttpApiData(..))
|
|
|
|
|
|
-- | Claims stored inside the (encrypted) tokens produced by this module.
data JWT = JWT
  { issuer     :: Text
    -- ^ Name of the issuing server; serialised under the key @iss@.
  , expiration :: UTCTime
    -- ^ Point in time after which the token is no longer valid; serialised
    -- under the key @exp@.
  , jti        :: UUID
    -- ^ Token identifier; serialised under the key @jti@.
  } deriving (Show, Eq)
|
|
|
|
-- | Encode the claims as a JSON object with the keys @iss@, @exp@ and @jti@.
instance ToJSON JWT where
  toJSON JWT{ issuer = iss, expiration = expiry, jti = tokenId } =
    object
      [ "iss" .= iss
      , "exp" .= expiry
      , "jti" .= tokenId
      ]
|
|
|
|
-- | Decode the claims from a JSON object with the keys @iss@, @exp@ and @jti@.
--
-- 'withObject' is used so that a non-object JSON value (array, string, ...)
-- results in an ordinary parser failure instead of a pattern-match exception.
instance FromJSON JWT where
  parseJSON = withObject "JWT" $ \o ->
    JWT <$> o .: "iss" <*> o .: "exp" <*> o .: "jti"
|
|
|
|
|
|
-- | Token bundle handed out to clients: an access token, its lifetime and an
-- optional refresh token.
data JWTWrapper = JWTW
  { acessToken   :: String
    -- ^ The access token (serialised as @access_token@). The spelling of the
    -- field name is kept as-is because the field is exported.
  , expiresIn    :: NominalDiffTime
    -- ^ Lifetime of the access token (serialised as @expires_in@).
  , refreshToken :: Maybe String
    -- ^ Refresh token, if any (serialised as @refresh_token@).
  } deriving (Show)
|
|
|
|
-- | Serialise the wrapper as a JSON object with the keys @access_token@,
-- @token_type@, @expires_in@ and @refresh_token@.
--
-- @expires_in@ is written as a whole number of seconds. 'fromEnum' must not be
-- used for this conversion: on 'NominalDiffTime' it yields the underlying
-- picosecond count, which is neither what OAuth2 clients expect nor what the
-- 'FromJSON' instance of this type reads back.
instance ToJSON JWTWrapper where
  toJSON (JWTW a e r) = object
    [ "access_token"  .= a
    , "token_type"    .= ("JWT" :: Text)
    , "expires_in"    .= (truncate e :: Integer)
    , "refresh_token" .= r
    ]
|
|
|
|
-- | Parse a wrapper from a JSON object. @access_token@ and @expires_in@ are
-- mandatory, @refresh_token@ is optional.
--
-- 'withObject' turns a non-object input into a regular parser failure rather
-- than a pattern-match exception.
instance FromJSON JWTWrapper where
  parseJSON = withObject "JWTWrapper" $ \o -> JWTW
    <$> o .:  "access_token"
    <*> o .:  "expires_in"
    <*> o .:? "refresh_token"
|
|
|
|
-- | Read a 'JWTWrapper' from HTTP request data by decoding it as JSON.
--
-- 'parseUrlPiece' is defined in addition to 'parseHeader' because the class
-- requires one of 'parseUrlPiece' / 'parseQueryParam'; without it those two
-- default methods are defined in terms of each other and never terminate.
instance FromHttpApiData JWTWrapper where
  -- Textual input is handled by the same JSON decoder as header input.
  parseUrlPiece = parseHeader . encodeUtf8

  parseHeader = maybe (Left "Invalid JWT wrapper") Right . decode . fromStrict
|
|
|
|
|
|
-- | A pending authorisation request, stored under its authorisation code
-- until the code is redeemed or removed.
data AuthRequest user = AuthRequest
  { client         :: String
    -- ^ Identifier of the requesting client; compared against the client id
    -- supplied to 'verify'.
  , codeExpiration :: NominalDiffTime
    -- ^ How long the authorisation code stays registered.
  , user           :: user
    -- ^ The user the request was made for.
  , scopes         :: [Scope user]
    -- ^ Scopes attached to the request.
  }
|
|
|
|
|
|
|
|
-- | Mutable server state shared between request handlers.
data State user = State
  { activeCodes  :: Map Text (AuthRequest user)
    -- ^ Outstanding authorisation codes and the requests they belong to.
  , activeTokens :: Map UUID (user, [Scope user])
    -- ^ Issued tokens, keyed by their @jti@, with the user and scopes they
    -- were issued for.
  , publicKey    :: Jwk
    -- ^ Key used to encrypt tokens.
  , privateKey   :: Jwk
    -- ^ Key used to decrypt tokens.
  }
|
|
|
|
-- | Handle to the shared server 'State', held in a 'TVar' so that it can be
-- read and modified inside STM transactions.
type AuthState user = TVar (State user)
|
|
|
|
-- | Create an authorisation code for a request and register it in the state.
--
-- The code is the Base64 encoding (with @/@ and @=@ percent-escaped) of the
-- client name, the given URL string, the current time and the expiry time,
-- with all spaces removed. It is derived from these inputs only and contains
-- no random component.
--
-- Returns 'Nothing' if an identical code is already registered, otherwise
-- 'Just' the code. On success a background thread is started that removes the
-- code again once @codeExpiration@ has elapsed.
genUnencryptedCode :: AuthRequest user  -- ^ request to register
                   -> String            -- ^ URL string mixed into the code (presumably the redirect URL — confirm with callers)
                   -> AuthState user    -- ^ shared server state
                   -> IO (Maybe Text)
genUnencryptedCode req url state = do
  now <- getCurrentTime
  let expiresAt  = req.codeExpiration `addUTCTime` now
      simpleCode = replace "/" "%2F" . replace "=" "%3D" . encodeBase64 . pack . filter (/= ' ')
                 $ req.client <> url <> show now <> show expiresAt
  -- Check-and-insert happens in one transaction so two concurrent requests
  -- cannot both claim the same code.
  success <- atomically . stateTVar state $ \s ->
    if M.member simpleCode s.activeCodes
      then (False, s)
      else (True, s { activeCodes = M.insert simpleCode req s.activeCodes })
  if success
    then do
      expire simpleCode req.codeExpiration
      return (Just simpleCode)
    else return Nothing
  where
    -- Remove the code from the state after the given lifetime has passed.
    expire :: Text -> NominalDiffTime -> IO ()
    expire code lifetime = void . forkIO $ do
      -- 'threadDelay' expects microseconds. 'fromEnum' on a 'NominalDiffTime'
      -- would yield picoseconds, so convert from seconds explicitly.
      threadDelay . max 0 . truncate $ lifetime * 1000000
      atomically . modifyTVar' state $ \s -> s { activeCodes = M.delete code s.activeCodes }
|
|
|
|
|
|
-- | Redeem an authorisation code.
--
-- The code is looked up and removed from the state in a single transaction,
-- so it can be redeemed at most once (it is removed even if the client check
-- below fails). If a client id is supplied it must equal the client stored
-- with the request; if none is supplied the check is skipped.
--
-- No expiry check is performed here: a code is valid exactly as long as it is
-- present in 'activeCodes'.
--
-- Returns the user and scopes of the request on success, 'Nothing' otherwise.
verify :: Text -> Maybe String -> AuthState user -> IO (Maybe (user, [Scope user]))
verify code mClientID state = do
  mData <- atomically . stateTVar state $ \s ->
    (M.lookup code s.activeCodes, s { activeCodes = M.delete code s.activeCodes })
  return $ case mData of
    Just (AuthRequest clientID' _ u s)
      | maybe True (== clientID') mClientID -> Just (u, s)
    _ -> Nothing
|
|
|
|
|
|
-- | Issue a fresh access token / refresh token pair for a user.
--
-- Both tokens carry the same newly generated @jti@ and are encrypted with the
-- public key from the state. On success the @jti@ is registered in
-- 'activeTokens' together with the user and scopes.
--
-- Calls 'error' with the shown 'JwtError' if encrypting either token fails
-- (the access token's error takes precedence).
mkToken :: user -> [Scope user] -> AuthState user -> IO JWTWrapper
mkToken u scopes state = do
  pubKey <- publicKey <$> readTVarIO state
  now <- getCurrentTime
  uuid <- nextRandom
  let lifetimeAT = 120 :: NominalDiffTime -- TODO make configurable
      lifetimeRT = nominalDay -- TODO make configurable
      at = JWT "Oauth2MockServer" (lifetimeAT `addUTCTime` now) uuid
      rt = JWT "Oauth2MockServer" (lifetimeRT `addUTCTime` now) uuid
      -- Serialise the claims to JSON and wrap them in a JWE.
      encrypt claims = jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode claims)
  encodedAT <- encrypt at
  encodedRT <- encrypt rt
  -- Combine both results so the success case can be matched totally, without
  -- resorting to partial extraction of the 'Right' values.
  case (,) <$> encodedAT <*> encodedRT of
    Left e -> error $ show e
    Right (Jwt aToken, Jwt rToken) -> do
      atomically . modifyTVar' state $ \s -> s { activeTokens = M.insert uuid (u, scopes) (activeTokens s) }
      return $ JWTW (BS.unpack aToken) lifetimeAT (Just $ BS.unpack rToken)
|
|
|
|
-- | Decode a token with the private key currently held in the state.
--
-- The token text is UTF-8 encoded and passed to 'jwkDecode'; its result is
-- returned unchanged.
decodeToken :: Text -> AuthState user -> IO (Either JwtError JwtContent)
decodeToken token state = do
  st <- readTVarIO state
  jwkDecode (privateKey st) (encodeUtf8 token)
|
|
|
|
renewToken :: JWTWrapper -> AuthState user -> IO (Maybe JWTWrapper)
|
|
renewToken (JWTW _ _ rt) state = case rt >>= stripPrefix "Bearer " . pack of
|
|
Just t -> decodeToken t state >>= \case
|
|
Right (Jwe (header, body)) -> do
|
|
let jwt = fromJust . decode @JWT $ fromStrict body
|
|
now <- getCurrentTime
|
|
if now <= expiration jwt then return Nothing else do
|
|
mUser <- atomically . stateTVar state $ \s ->
|
|
let (key, tokens) = M.updateLookupWithKey (\_ _ -> Nothing) (jti jwt) s.activeTokens
|
|
in (key, s { activeTokens = tokens })
|
|
case mUser of
|
|
Just (u, scopes) -> Just <$> mkToken u scopes state
|
|
Nothing -> return Nothing
|
|
Left _ -> return Nothing
|
|
Nothing -> return Nothing
|