store active tokens & scopes

This commit is contained in:
David Mosbach 2024-01-10 21:45:28 +01:00
parent c7989034b4
commit 2716ccf2c0
5 changed files with 98 additions and 64 deletions

View File

@ -20,6 +20,7 @@ extra-source-files:
library
exposed-modules:
AuthCode
DB
Server
User
other-modules:
@ -45,6 +46,7 @@ library
, text
, time
, transformers
, uuid
, warp
default-language: Haskell2010
@ -74,6 +76,7 @@ executable oauth2-mock-server-exe
, text
, time
, transformers
, uuid
, warp
default-language: Haskell2010
@ -104,5 +107,6 @@ test-suite oauth2-mock-server-test
, text
, time
, transformers
, uuid
, warp
default-language: Haskell2010

View File

@ -36,6 +36,7 @@ dependencies:
- jose-jwt
- base64
- http-api-data
- uuid
ghc-options:
- -Wall

View File

@ -1,17 +1,23 @@
{-# LANGUAGE OverloadedRecordDot, OverloadedStrings #-}
{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables #-}
module AuthCode
( State (..)
( State(..)
, AuthState
, AuthRequest(..)
, JWT(..)
, genUnencryptedCode
, verify
) where
import User
import Data.Aeson
import Data.Map.Strict (Map)
import Data.Maybe (isJust, fromMaybe)
import Data.Time.Clock
import Data.Text (pack, replace, Text)
import Data.Text.Encoding.Base64
import Data.UUID
import qualified Data.Map.Strict as M
@ -23,42 +29,62 @@ import Control.Monad.STM
import Jose.Jwk (Jwk(..))
data JWT = JWT
{ issuer :: Text
, expiration :: UTCTime
, jti :: UUID
} deriving (Show, Eq)
instance ToJSON JWT where
toJSON (JWT i e j) = object ["iss" .= i, "exp" .= e, "jti" .= j]
instance FromJSON JWT where
parseJSON (Object o) = JWT <$> o .: "iss" <*> o .: "exp" <*> o .: "jti"
data State = State
{ activeCodes :: Map Text (String, UTCTime) -- ^ maps auth codes to (client ID, expiration time)
data AuthRequest user = AuthRequest
{ client :: String
, codeExpiration :: NominalDiffTime
, user :: user
, scopes :: [Scope user]
}
data State user = State
{ activeCodes :: Map Text (AuthRequest user)
, activeTokens :: Map UUID (user, [Scope user])
, publicKey :: Jwk
, privateKey :: Jwk
} deriving Show
}
type AuthState = TVar State
type AuthState user = TVar (State user)
genUnencryptedCode :: String
genUnencryptedCode :: AuthRequest user
-> String
-> NominalDiffTime
-> AuthState
-> AuthState user
-> IO (Maybe Text)
genUnencryptedCode client url expiration state = do
genUnencryptedCode req url state = do
now <- getCurrentTime
let
expiresAt = expiration `addUTCTime` now
simpleCode = replace "/" "%2F" . replace "=" "%3D" . encodeBase64 . pack . filter (/= ' ') $ client <> url <> show now <> show expiresAt
expiresAt = req.codeExpiration `addUTCTime` now
simpleCode = replace "/" "%2F" . replace "=" "%3D" . encodeBase64 . pack . filter (/= ' ') $ req.client <> url <> show now <> show expiresAt
success <- atomically . stateTVar state $ \s ->
let mEntry = M.lookup simpleCode s.activeCodes
in
if isJust mEntry
then (False, s)
else (True, s{ activeCodes = M.insert simpleCode (client, expiresAt) s.activeCodes })
if success then expire simpleCode expiration state >> return (Just simpleCode) else return Nothing
else (True, s{ activeCodes = M.insert simpleCode req s.activeCodes })
if success then expire simpleCode req.codeExpiration state >> return (Just simpleCode) else return Nothing
expire :: Text -> NominalDiffTime -> AuthState -> IO ()
expire :: Text -> NominalDiffTime -> AuthState user -> IO ()
expire code time state = void . forkIO $ do
threadDelay $ fromEnum time
atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes }
verify :: Text -> Maybe String -> AuthState -> IO Bool
verify :: Text -> Maybe String -> AuthState user -> IO (Maybe (user, [Scope user]))
verify code mClientID state = do
now <- getCurrentTime
mData <- atomically $ do
@ -66,5 +92,5 @@ verify code mClientID state = do
modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes }
return result
return $ case mData of
Just (clientID', _) -> (fromMaybe clientID' mClientID) == clientID'
_ -> False
Just (AuthRequest clientID' _ u s) -> if (fromMaybe clientID' mClientID) == clientID' then Just (u, s) else Nothing
_ -> Nothing

4
src/DB.hs Normal file
View File

@ -0,0 +1,4 @@
module DB where
import User

View File

@ -12,7 +12,7 @@ import User
import Control.Applicative ((<|>))
import Control.Concurrent
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TVar (newTVarIO, readTVar)
import Control.Concurrent.STM.TVar (newTVarIO, readTVar, modifyTVar)
import Control.Exception (bracket)
import Control.Monad (unless)
import Control.Monad.IO.Class
@ -26,6 +26,7 @@ import Data.String (IsString (..))
import Data.Text hiding (elem, find, head, length, map, null, splitAt, tail, words)
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Time.Clock (NominalDiffTime (..), nominalDay, UTCTime(..), getCurrentTime, addUTCTime)
import Data.UUID.V4
import qualified Data.ByteString.Char8 as BS
import qualified Data.Map.Strict as Map
@ -103,13 +104,13 @@ type Auth user userData = BasicAuth "login" user
-- type Insert = "insert" :> Post '[JSON] User
type AuthHandler = ReaderT AuthState Handler
type AuthServer a = ServerT a AuthHandler
type AuthHandler user = ReaderT (AuthState user) Handler
type AuthServer user a = ServerT a (AuthHandler user)
toHandler :: AuthState -> AuthHandler a -> Handler a
toHandler :: forall user userData a . UserData user userData => AuthState user -> AuthHandler user a -> Handler a
toHandler s h = runReaderT h s
authServer :: forall user userData . UserData user userData => AuthServer (Auth user userData)
authServer :: forall user userData . UserData user userData => AuthServer user (Auth user userData)
authServer = handleAuth
where
handleAuth :: user
@ -118,18 +119,20 @@ authServer = handleAuth
-> QResType
-> QRedirect
-> Maybe QState
-> AuthHandler userData
-> AuthHandler user userData
handleAuth u scopes client responseType url mState = do
unless (isJust $ find (\c -> ident c == pack client) trustedClients) . -- TODO fetch trusted clients from db | TODO also check if the redirect url really belongs to the client
throwError $ err404 { errBody = "Not a trusted client."}
let responseType' = read @ResponseType responseType
let
responseType' = read @ResponseType responseType
scopes' = map (readScope @user @userData) $ words scopes
liftIO $ print responseType'
unless (responseType' == Code) $ throwError err500 { errBody = "Unsupported response type" }
mAuthCode <- asks (genUnencryptedCode client url 600) >>= liftIO
mAuthCode <- asks (genUnencryptedCode (AuthRequest client 600 u scopes') url) >>= liftIO
liftIO $ print mAuthCode
-- liftIO . putStrLn $ "user: " ++ show u ++ " | scopes: " ++ show (map (showScope @user @userData) scopes')
liftIO . putStrLn $ "user: " ++ show u ++ " | scopes: " ++ show (map (showScope @user @userData) scopes')
redirect $ addParams url mAuthCode mState
redirect :: Maybe ByteString -> AuthHandler userData
redirect :: Maybe ByteString -> AuthHandler user userData
redirect (Just url) = liftIO (print url) >> throwError err303 { errHeaders = [("Location", url)]}
redirect Nothing = throwError err500 { errBody = "Could not generate authorisation code."}
addParams :: String -> Maybe Text -> Maybe Text -> Maybe ByteString
@ -167,18 +170,6 @@ instance FromForm ClientData where
<*> parseMaybe "redirect_uri" f
data JWT = JWT
{ issuer :: Text
, expiration :: UTCTime
} deriving (Show, Eq)
instance ToJSON JWT where
toJSON (JWT i e) = object ["iss" .= i, "exp" .= e]
instance FromJSON JWT where
parseJSON (Object o) = JWT <$> o .: "iss" <*> o .: "exp"
data JWTWrapper = JWTW
{ token :: String
, expiresIn :: NominalDiffTime
@ -202,32 +193,37 @@ type Token = "token"
:> Post '[JSON] JWTWrapper
tokenEndpoint :: AuthServer Token
tokenEndpoint :: forall user userData . UserData user userData => AuthServer user Token
tokenEndpoint = provideToken
where
provideToken :: ClientData -> AuthHandler JWTWrapper
provideToken :: ClientData -> AuthHandler user JWTWrapper
provideToken client = do
unless (isNothing (clientID client >> clientSecret client)
|| Client (pack . fromJust $ clientID client) (pack . fromJust $ clientSecret client) `elem` trustedClients) .
throwError $ err500 { errBody = "Invalid client" }
valid <- asks (verify (pack $ authCode client) (clientID client)) >>= liftIO -- TODO verify redirect url here
unless valid . throwError $ err500 { errBody = "Invalid authorisation code" }
mUser <- asks (verify (pack $ authCode client) (clientID client)) >>= liftIO -- TODO verify redirect url here
unless (isJust mUser) . throwError $ err500 { errBody = "Invalid authorisation code" }
-- return JWT {token = "", tokenType = "JWT", expiration = 0.25 * nominalDay}
token <- asks mkToken >>= liftIO
let (user, scopes) = fromJust mUser
token <- asks (mkToken @user @userData user scopes) >>= liftIO
liftIO . putStrLn $ "token: " ++ show token
return token
mkToken :: AuthState -> IO JWTWrapper
mkToken state = do
mkToken :: forall user userData . UserData user userData
=> user -> [Scope user] -> AuthState user -> IO JWTWrapper
mkToken u scopes state = do
pubKey <- atomically $ readTVar state >>= return . publicKey
now <- getCurrentTime
uuid <- nextRandom
let
lifetime = nominalDay / 4 -- TODO make configurable
jwt = JWT "Oauth2MockServer" (lifetime `addUTCTime` now)
lifetime = nominalDay / 24 -- TODO make configurable
jwt = JWT "Oauth2MockServer" (lifetime `addUTCTime` now) uuid
encoded <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode jwt)
case encoded of
Right (Jwt token) -> return $ JWTW (BS.unpack token) lifetime
Right (Jwt token) -> do
atomically . modifyTVar state $ \s -> s { activeTokens = Map.insert uuid (u, scopes) (activeTokens s) }
return $ JWTW (BS.unpack token) lifetime
Left e -> error $ show e
@ -250,20 +246,21 @@ type UserList userData = Users
:> Get '[JSON] [userData] -- TODO support query params
userEndpoint :: forall user userData . UserData user userData => AuthServer (Me userData)
userEndpoint :: forall user userData . UserData user userData => AuthServer user (Me userData)
userEndpoint = handleUserData
where
handleUserData :: Text -> AuthHandler userData
handleUserData :: Text -> AuthHandler user userData
handleUserData jwtw = do
let mToken = stripPrefix "Bearer " jwtw
unless (isJust mToken) . throwError $ err500 { errBody = "Invalid token format"}
token <- asks (decodeToken (fromJust mToken)) >>= liftIO
token <- asks (decodeToken @user @userData (fromJust mToken)) >>= liftIO
liftIO $ putStrLn "decoded token:" >> print token
case token of
Left e -> throwError $ err500 { errBody = fromString $ show e }
Right (Jwe (header, body)) -> do
let jwt = fromJust . decode @JWT $ fromStrict body
-- TODO check if token grants access, then read logged in user from cookie
liftIO $ print jwt
return mempty
-- let
-- scopes' = map (readScope @user @userData) $ words scopes
@ -271,15 +268,15 @@ userEndpoint = handleUserData
-- liftIO . putStrLn $ "user: " ++ show u ++ " | scopes: " ++ show (map (showScope @user @userData) scopes')
-- return uData
decodeToken :: Text -> AuthState -> IO (Either JwtError JwtContent)
decodeToken :: forall user userData . UserData user userData => Text -> AuthState user -> IO (Either JwtError JwtContent)
decodeToken token state = do
prKey <- atomically $ readTVar state >>= return . privateKey
jwkDecode prKey $ encodeUtf8 token
userListEndpoint :: forall user userData . UserData user userData => AuthServer (UserList userData)
userListEndpoint :: forall user userData . UserData user userData => AuthServer user (UserList userData)
userListEndpoint = handleUserData
where
handleUserData :: Text -> AuthHandler [userData]
handleUserData :: Text -> AuthHandler user [userData]
handleUserData jwtw = do
undefined
@ -293,13 +290,13 @@ type Routing user userData = Auth user userData
:<|> Me userData
:<|> UserList userData
routing :: forall user userData . UserData user userData => AuthServer (Routing user userData)
routing :: forall user userData . UserData user userData => AuthServer user (Routing user userData)
routing = authServer @user @userData
:<|> tokenEndpoint
:<|> tokenEndpoint @user @userData
:<|> userEndpoint @user @userData
:<|> userListEndpoint @user @userData
exampleAuthServer :: AuthServer (Routing User (Map.Map Text Text))
exampleAuthServer :: AuthServer User (Routing User (Map.Map Text Text))
exampleAuthServer = routing
@ -309,8 +306,8 @@ authAPI = Proxy
-- insecureOAuthMock :: Application
-- insecureOAuthMock = authAPI `serve` exampleAuthServer
insecureOAuthMock' :: [User] -> AuthState -> Application
insecureOAuthMock' testUsers s = serveWithContext authAPI c $ hoistServerWithContext authAPI p (toHandler s) exampleAuthServer
insecureOAuthMock' :: [User] -> AuthState User -> Application
insecureOAuthMock' testUsers s = serveWithContext authAPI c $ hoistServerWithContext authAPI p (toHandler @User @(Map.Map Text Text) s) exampleAuthServer
where
c = authenticate testUsers :. EmptyContext
p = Proxy :: Proxy '[BasicAuthCheck User]
@ -328,7 +325,7 @@ authenticate users = BasicAuthCheck $ \authData -> do
runMockServer :: Int -> IO ()
runMockServer port = do
state <- mkState
state <- mkState @User @(Map.Map Text Text)
run port $ insecureOAuthMock' testUsers state
-- runMockServer' :: Int -> IO ()
@ -339,9 +336,11 @@ runMockServer port = do
-- runClientM (frontend $ BasicAuthData "foo@bar.com" "0000") (mkClientEnv mgr (BaseUrl Http "localhost" port ""))
-- >>= print
mkState :: IO AuthState
mkState :: forall user userData . UserData user userData => IO (AuthState user)
mkState = do
(publicKey, privateKey) <- generateRsaKeyPair 256 (KeyId "Oauth2MockKey") Enc Nothing
let activeCodes = Map.empty
let
activeCodes = Map.empty
activeTokens = Map.empty
newTVarIO State{..}