254 lines
8.0 KiB
Haskell
254 lines
8.0 KiB
Haskell
-- SPDX-FileCopyrightText: 2024 UniWorX Systems
|
|
-- SPDX-FileContributor: David Mosbach <david.mosbach@uniworx.de>
|
|
--
|
|
-- SPDX-License-Identifier: AGPL-3.0-or-later
|
|
|
|
{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables, TypeApplications, LambdaCase, DeriveGeneric, AllowAmbiguousTypes #-}
|
|
|
|
module AuthCode
|
|
( State(..)
|
|
, AuthState
|
|
, AuthRequest(..)
|
|
, TokenParams(..)
|
|
, JWT(..)
|
|
, JWTWrapper(..)
|
|
, genUnencryptedCode
|
|
, verify
|
|
, mkToken
|
|
, decodeToken
|
|
, renewToken
|
|
) where
|
|
|
|
import Prelude hiding (exp)
|
|
|
|
import User
|
|
|
|
import Data.Aeson
|
|
import Data.Bool (bool)
|
|
import Data.ByteString (ByteString (..), fromStrict, toStrict)
|
|
import Data.Either (fromRight)
|
|
import Data.List ((\\))
|
|
import Data.Map.Strict (Map)
|
|
import Data.Maybe (isJust, fromMaybe, fromJust, catMaybes)
|
|
import Data.Time.Calendar
|
|
import Data.Time.Clock
|
|
import Data.Text (pack, replace, Text, stripPrefix)
|
|
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
|
|
import Data.Text.Encoding.Base64
|
|
import Data.UUID hiding (null)
|
|
import Data.UUID.V4
|
|
|
|
import qualified Data.ByteString.Char8 as BS
|
|
import qualified Data.Map.Strict as M
|
|
|
|
import Control.Concurrent (forkIO, threadDelay)
|
|
import Control.Concurrent.STM.TVar
|
|
import Control.Monad (void, (>=>))
|
|
import Control.Monad.STM
|
|
|
|
import GHC.Generics
|
|
|
|
import Jose.Jwa
|
|
import Jose.Jwe
|
|
import Jose.Jwk (Jwk(..))
|
|
import Jose.Jwt hiding (decode, encode)
|
|
import qualified Jose.Jws as Jws
|
|
|
|
import Servant.API (FromHttpApiData(..))
|
|
|
|
import System.Environment (getEnv)
|
|
|
|
|
|
--------------
|
|
---- Tokens ----
|
|
--------------
|
|
|
|
-- | Claims carried by the access and refresh tokens issued by this module.
data JWT = JWT
  { issuer     :: Text
    -- ^ Serialised under the @iss@ key.
  , expiration :: UTCTime
    -- ^ Point in time after which the token is rejected; serialised as @exp@.
  , jti        :: UUID
    -- ^ Token identifier; also the key under which the token parameters are
    --   stored in 'activeTokens'. Serialised as @jti@.
  } deriving (Eq, Show)
|
|
|
|
-- | Encodes a 'JWT' as a JSON object with the claim names @iss@, @exp@ and @jti@.
instance ToJSON JWT where
  toJSON token =
    object
      [ "iss" .= token.issuer
      , "exp" .= token.expiration
      , "jti" .= token.jti
      ]
|
|
|
|
-- | Decodes a 'JWT' from a JSON object with the claims @iss@, @exp@ and @jti@.
--
-- Any non-object JSON value is reported as a regular parse failure (instead of
-- crashing with an incomplete-pattern exception), so callers using 'decode'
-- simply get 'Nothing'.
instance FromJSON JWT where
  parseJSON = withObject "JWT" $ \o ->
    JWT
      <$> o .: "iss"
      <*> o .: "exp"
      <*> o .: "jti"
|
|
|
|
-- | Claims of the OpenID Connect ID token.
--
-- The JSON instances are derived generically with aeson's default options, so
-- the record field names are used verbatim as JSON keys.
data IDToken = IDT
  { iss       :: Text
    -- ^ Issuer.
  , sub       :: Text
    -- ^ Subject.
  , aud       :: [Text]
    -- ^ Audience(s).
  , exp       :: NominalDiffTime
    -- ^ Expiration, as an offset from 1970-01-01 00:00 UTC.
  , iat       :: NominalDiffTime
    -- ^ Time of issue, as an offset from 1970-01-01 00:00 UTC.
  , auth_time :: Maybe NominalDiffTime
    -- ^ Time of authentication, same reference date as 'iat'.
  , nonce     :: Maybe Text
    -- ^ Nonce handed in with the authorisation request, if any.
  } deriving (Generic, Show)

instance ToJSON IDToken where
  toJSON = genericToJSON defaultOptions

instance FromJSON IDToken where
  parseJSON = genericParseJSON defaultOptions
|
|
|
|
|
|
-- | Token response handed back to a client: the encoded tokens plus the
-- lifetime of the access token.
data JWTWrapper = JWTW
  { acessToken   :: String
    -- ^ Encoded access token. (The field name is misspelt, but it is part of
    --   the exported interface.)
  , expiresIn    :: NominalDiffTime
    -- ^ Lifetime of the access token.
  , refreshToken :: Maybe String
    -- ^ Encoded refresh token, if one was issued.
  , idToken      :: Maybe String
    -- ^ Encoded ID token, if one was issued.
  } deriving (Show)
|
|
|
|
-- | Serialises the wrapper as a token response object with the keys
-- @access_token@, @token_type@, @expires_in@, @refresh_token@ and @id_token@.
instance ToJSON JWTWrapper where
  toJSON (JWTW a e r i) =
    object
      [ "access_token"  .= a
      , "token_type"    .= ("JWT" :: Text)
        -- 'fromEnum' on a 'NominalDiffTime' yields the underlying picosecond
        -- count, which made @expires_in@ 10^12 times too large and
        -- inconsistent with the 'FromJSON' instance (which reads seconds).
        -- Emit whole seconds instead.
      , "expires_in"    .= (round e :: Int)
      , "refresh_token" .= r
      , "id_token"      .= i
      ]
|
|
|
|
-- | Parses a token response object. @access_token@ and @expires_in@ are
-- mandatory, @refresh_token@ and @id_token@ are optional.
--
-- Non-object JSON values yield an ordinary parse failure rather than an
-- incomplete-pattern exception.
instance FromJSON JWTWrapper where
  parseJSON = withObject "JWTWrapper" $ \o ->
    JWTW
      <$> o .:  "access_token"
      <*> o .:  "expires_in"
      <*> o .:? "refresh_token"
      <*> o .:? "id_token"
|
|
|
|
-- | Reads a 'JWTWrapper' from its JSON representation.
--
-- 'parseUrlPiece' is defined explicitly: the class' minimal definition is
-- @parseUrlPiece | parseQueryParam@, and with neither given their defaults
-- refer to each other and never terminate.
instance FromHttpApiData JWTWrapper where
  -- URL pieces and query parameters use the same JSON decoding as headers.
  parseUrlPiece = parseHeader . encodeUtf8

  parseHeader bs =
    maybe (Left "Invalid JWT wrapper") Right (decode (fromStrict bs))
|
|
|
|
|
|
-------------
|
|
---- State ----
|
|
-------------
|
|
|
|
-- | A pending authorisation request, stored under its authorisation code until
-- the code is redeemed or removed.
data AuthRequest user = AuthRequest
  { client         :: String
    -- ^ Client the code was issued for; compared against the client id
    --   supplied on redemption.
  , codeExpiration :: NominalDiffTime
    -- ^ Lifetime of the authorisation code.
  , user           :: user
    -- ^ The user the request was made for.
  , scopes         :: [Scope' user]
    -- ^ Scopes attached to the request.
  , rNonce         :: Maybe Text
    -- ^ Optional nonce, passed on into the token parameters.
  }
|
|
|
|
-- | Everything needed to mint tokens: the user, the scopes and an optional nonce.
type TokenParams user = (user, [Scope' user], Maybe Text)
|
|
|
|
-- | Mutable server state shared between request handlers.
data State user = State
  { activeCodes  :: Map Text (AuthRequest user)
    -- ^ Outstanding authorisation codes, keyed by the code itself.
  , activeTokens :: Map UUID (TokenParams user)
    -- ^ Parameters of issued tokens, keyed by the tokens' @jti@.
  , publicKey    :: Jwk
    -- ^ Key used to encrypt access and refresh tokens.
  , privateKey   :: Jwk
    -- ^ Key used to sign ID tokens and to decode incoming tokens.
  }
|
|
|
|
-- | The shared 'State', held in a 'TVar' so it can be updated atomically.
type AuthState user = TVar (State user)
|
|
|
|
-----------------
|
|
---- Functions ----
|
|
-----------------
|
|
|
|
-- | Generate an authorisation code for the given request and register it in
-- the shared state.
--
-- The code is the Base64 encoding of client, URL, current time and expiry
-- time (spaces removed), with @=@ and @/@ percent-escaped. Returns 'Nothing'
-- if an identical code is already active. On success a background thread is
-- forked that removes the code again once its lifetime has elapsed.
genUnencryptedCode :: AuthRequest user
                   -> String            -- ^ URL mixed into the code
                   -> AuthState user
                   -> IO (Maybe Text)
genUnencryptedCode req url state = do
  now <- getCurrentTime
  let expiresAt  = req.codeExpiration `addUTCTime` now
      simpleCode = replace "/" "%2F" . replace "=" "%3D" . encodeBase64 . pack . filter (/= ' ')
                 $ req.client <> url <> show now <> show expiresAt
  success <- atomically . stateTVar state $ \s ->
    if M.member simpleCode s.activeCodes
      then (False, s)
      else (True, s{ activeCodes = M.insert simpleCode req s.activeCodes })
  if success
    then Just simpleCode <$ expire simpleCode req.codeExpiration
    else return Nothing
  where
    -- Drop the code from the state after the given lifetime.
    expire :: Text -> NominalDiffTime -> IO ()
    expire code time = void . forkIO $ do
      threadDelay $ toMicroseconds time
      atomically . modifyTVar' state $ \s -> s{ activeCodes = M.delete code s.activeCodes }

    -- 'threadDelay' expects microseconds. 'fromEnum' on a 'NominalDiffTime'
    -- returns picoseconds, which would delay a million times too long, so
    -- convert explicitly.
    toMicroseconds :: NominalDiffTime -> Int
    toMicroseconds time = round (time * 1000000)
|
|
|
|
|
|
-- | Redeem an authorisation code.
--
-- The code is removed from the state in the same transaction that looks it
-- up, so it can be used at most once, even if the client check below fails.
-- Returns the token parameters when the code was active and the supplied
-- client id (if any) matches the one stored with the request.
verify :: Text                 -- ^ authorisation code
       -> Maybe String         -- ^ client id; 'Nothing' skips the client check
       -> AuthState user
       -> IO (Maybe (TokenParams user))
verify code mClientID state = do
  mData <- atomically . stateTVar state $ \s ->
    (M.lookup code s.activeCodes, s{ activeCodes = M.delete code s.activeCodes })
  return $ case mData of
    Just (AuthRequest clientID' _ u s n)
      | fromMaybe clientID' mClientID == clientID' -> Just (u, s, n)
    _ -> Nothing
|
|
|
|
|
|
-- | Issue a fresh set of tokens for the given parameters.
--
-- Access and refresh token are encrypted with the public key from the state
-- and share one freshly generated @jti@, under which the token parameters are
-- recorded in 'activeTokens'. The ID token is signed with the private key and
-- is only handed out when the @OpenID@ scope is part of the scopes.
--
-- Reads the environment variable @OAUTH2_SERVER_PORT@ (via 'getEnv', which
-- throws if it is unset) and calls 'error' if encoding any token fails.
mkToken :: forall user userData . UserData user userData
        => TokenParams user
        -> Maybe Text  -- ^ client_id
        -> AuthState user
        -> IO JWTWrapper
mkToken (u, scopes, nonce) clientID state = do
  s0 <- readTVarIO state
  let pubKey  = publicKey s0
      privKey = privateKey s0
  now  <- getCurrentTime
  uuid <- nextRandom
  port <- pack <$> getEnv "OAUTH2_SERVER_PORT"
  let
    lifetimeAT = 3600 :: NominalDiffTime  -- TODO make configurable
    lifetimeRT = nominalDay               -- TODO make configurable
    lifetimeIT = 3600 :: NominalDiffTime  -- TODO make configurable
    -- Reference date for the numeric ID token timestamps.
    itRefDate  = UTCTime (fromGregorian 1970 1 1) 0
    at = JWT "Oauth2MockServer" (lifetimeAT `addUTCTime` now) uuid
    rt = JWT "Oauth2MockServer" (lifetimeRT `addUTCTime` now) uuid
    it = IDT
      { iss = "http://localhost:" <> port  -- TODO maybe make configurable
      , sub = pack . show $ userID @user @userData u
      , aud = catMaybes [clientID]
      , exp = (lifetimeIT `addUTCTime` now) `diffUTCTime` itRefDate
      , iat = now `diffUTCTime` itRefDate
      , auth_time = Just $ now `diffUTCTime` itRefDate
      , nonce = nonce
      }
  encodedAT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode at)
  encodedRT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode rt)
  encodedIT <- Jws.jwkEncode RS256 privKey (Nested . Jwt . toStrict $ encode it)
  -- Combine the three results; the first failure (in the order access,
  -- refresh, ID token) wins. Matching on the tuple avoids partial extraction.
  case (,,) <$> encodedAT <*> encodedRT <*> encodedIT of
    Left e -> error $ show e
    Right (Jwt aToken, Jwt rToken, Jwt iToken) -> do
      atomically . modifyTVar' state $ \s ->
        s { activeTokens = M.insert uuid (u, scopes, nonce) (activeTokens s) }
      return $ JWTW
        { acessToken   = BS.unpack aToken
        , expiresIn    = lifetimeAT
        , refreshToken = Just $ BS.unpack rToken
          -- The ID token belongs to OpenID Connect: include it exactly when
          -- the OpenID scope is present (the condition used to be inverted).
        , idToken      = if Left OpenID `elem` scopes
                           then Just $ BS.unpack iToken
                           else Nothing
        }
|
|
|
|
-- | Decode an encoded token using the private key held in the state.
decodeToken :: Text -> AuthState user -> IO (Either JwtError JwtContent)
decodeToken token state = do
  current <- readTVarIO state
  jwkDecode (privateKey current) (encodeUtf8 token)
|
|
|
|
-- | Exchange a token for a new set of tokens.
--
-- Returns 'Nothing' when the token cannot be decoded, is not an encrypted
-- token, does not carry valid 'JWT' claims, has expired, is unknown to the
-- state, or when the requested scopes are not a subset of the scopes the
-- token was issued with. A known token is removed from the state as soon as
-- it is looked up, so it can be used for renewal only once.
renewToken :: forall user userData . UserData user userData
           => Text  -- ^ token
           -> [Scope' user]
           -> Maybe Text  -- ^ client_id
           -> AuthState user
           -> IO (Maybe JWTWrapper)  -- TODO more descriptive failures
renewToken t scopes clientID state = decodeToken t state >>= \case
  Right (Jwe (_, body)) -> case decode @JWT (fromStrict body) of
    -- A body that is not a valid claim set is a failed renewal, not a crash.
    Nothing  -> return Nothing
    Just jwt -> do
      now <- getCurrentTime
      if now >= expiration jwt then return Nothing else do
        -- Look the token up and remove it in one atomic step.
        mParams <- atomically . stateTVar state $ \s ->
          let (found, remaining) = M.updateLookupWithKey (\_ _ -> Nothing) (jti jwt) s.activeTokens
          in  (found, s { activeTokens = remaining })
        case mParams of
          Just (u, scopes', nonce)
            | null (scopes \\ scopes') ->
                Just <$> mkToken @user @userData (u, scopes, nonce) clientID state
          _ -> return Nothing
  -- Decoding errors and any other kind of JWT content.
  _ -> return Nothing
|
|
|