From 38e831c9b12b85a13d3bad92cf8777ea7fbae91e Mon Sep 17 00:00:00 2001 From: David Mosbach Date: Sun, 21 Jan 2024 20:35:40 +0100 Subject: [PATCH] refresh token can be exchanged at token endpoint --- src/AuthCode.hs | 98 ++++++++++++++++++++++++++++++++++++++++++++----- src/Server.hs | 79 +++++++-------------------------------- 2 files changed, 102 insertions(+), 75 deletions(-) diff --git a/src/AuthCode.hs b/src/AuthCode.hs index 5bb5579..676d530 100644 --- a/src/AuthCode.hs +++ b/src/AuthCode.hs @@ -3,27 +3,36 @@ -- -- SPDX-License-Identifier: AGPL-3.0-or-later -{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables #-} +{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables, TypeApplications, LambdaCase #-} module AuthCode ( State(..) , AuthState , AuthRequest(..) , JWT(..) +, JWTWrapper(..) , genUnencryptedCode , verify +, mkToken +, decodeToken +, renewToken ) where import User import Data.Aeson +import Data.ByteString (ByteString (..), fromStrict, toStrict) +import Data.Either (fromRight) import Data.Map.Strict (Map) -import Data.Maybe (isJust, fromMaybe) +import Data.Maybe (isJust, fromMaybe, fromJust) import Data.Time.Clock -import Data.Text (pack, replace, Text) +import Data.Text (pack, replace, Text, stripPrefix) +import Data.Text.Encoding (decodeUtf8, encodeUtf8) import Data.Text.Encoding.Base64 import Data.UUID +import Data.UUID.V4 +import qualified Data.ByteString.Char8 as BS import qualified Data.Map.Strict as M import Control.Concurrent (forkIO, threadDelay) @@ -31,7 +40,12 @@ import Control.Concurrent.STM.TVar import Control.Monad (void, (>=>)) import Control.Monad.STM +import Jose.Jwa +import Jose.Jwe import Jose.Jwk (Jwk(..)) +import Jose.Jwt hiding (decode, encode) + +import Servant.API (FromHttpApiData(..)) data JWT = JWT @@ -47,6 +61,31 @@ instance FromJSON JWT where parseJSON (Object o) = JWT <$> o .: "iss" <*> o .: "exp" <*> o .: "jti" +data JWTWrapper = JWTW + { acessToken :: String + , expiresIn :: NominalDiffTime + , refreshToken :: Maybe String + } deriving (Show) + +instance ToJSON JWTWrapper where + toJSON (JWTW a e r) = object + [ "access_token" .= a + , "token_type" .= ("JWT" :: Text) + , "expires_in" .= fromEnum e + , "refresh_token" .= r ] + +instance FromJSON JWTWrapper where + parseJSON (Object o) = JWTW + <$> o .: "access_token" + <*> o .: "expires_in" + <*> o .:? "refresh_token" + +instance FromHttpApiData JWTWrapper where + parseHeader bs = case decode (fromStrict bs) of + Just x -> Right x + Nothing -> Left "Invalid JWT wrapper" + + data AuthRequest user = AuthRequest { client :: String , codeExpiration :: NominalDiffTime @@ -81,12 +120,11 @@ genUnencryptedCode req url state = do then (False, s) else (True, s{ activeCodes = M.insert simpleCode req s.activeCodes }) if success then expire simpleCode req.codeExpiration state >> return (Just simpleCode) else return Nothing - - -expire :: Text -> NominalDiffTime -> AuthState user -> IO () -expire code time state = void . forkIO $ do - threadDelay $ fromEnum time - atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes } + where + expire :: Text -> NominalDiffTime -> AuthState user -> IO () + expire code time state = void . forkIO $ do + threadDelay $ fromEnum time + atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes } verify :: Text -> Maybe String -> AuthState user -> IO (Maybe (user, [Scope user])) @@ -99,3 +137,45 @@ verify code mClientID state = do return $ case mData of Just (AuthRequest clientID' _ u s) -> if (fromMaybe clientID' mClientID) == clientID' then Just (u, s) else Nothing _ -> Nothing + + +mkToken :: user -> [Scope user] -> AuthState user -> IO JWTWrapper +mkToken u scopes state = do + pubKey <- atomically $ readTVar state >>= return . publicKey + now <- getCurrentTime + uuid <- nextRandom + let + lifetimeAT = 120 :: NominalDiffTime -- TODO make configurable + lifetimeRT = nominalDay -- TODO make configurable + at = JWT "Oauth2MockServer" (lifetimeAT `addUTCTime` now) uuid + rt = JWT "Oauth2MockServer" (lifetimeRT `addUTCTime` now) uuid + encodedAT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode at) + encodedRT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode rt) + case encodedAT >> encodedRT of + Right _ -> do + let Jwt aToken = fromRight undefined encodedAT + Jwt rToken = fromRight undefined encodedRT + atomically . modifyTVar state $ \s -> s { activeTokens = M.insert uuid (u, scopes) (activeTokens s) } + return $ JWTW (BS.unpack aToken) lifetimeAT (Just $ BS.unpack rToken) + Left e -> error $ show e + +decodeToken :: Text -> AuthState user -> IO (Either JwtError JwtContent) +decodeToken token state = do + prKey <- atomically $ readTVar state >>= return . privateKey + jwkDecode prKey $ encodeUtf8 token + +renewToken :: JWTWrapper -> AuthState user -> IO (Maybe JWTWrapper) +renewToken (JWTW _ _ rt) state = case rt >>= stripPrefix "Bearer " . pack of + Just t -> decodeToken t state >>= \case + Right (Jwe (header, body)) -> do + let jwt = fromJust . decode @JWT $ fromStrict body + now <- getCurrentTime + if now <= expiration jwt then return Nothing else do + mUser <- atomically . stateTVar state $ \s -> + let (key, tokens) = M.updateLookupWithKey (\_ _ -> Nothing) (jti jwt) s.activeTokens + in (key, s { activeTokens = tokens }) + case mUser of + Just (u, scopes) -> Just <$> mkToken u scopes state + Nothing -> return Nothing + Left _ -> return Nothing + Nothing -> return Nothing diff --git a/src/Server.hs b/src/Server.hs index b955ff9..922d66a 100644 --- a/src/Server.hs +++ b/src/Server.hs @@ -26,26 +26,20 @@ import Control.Monad.Trans.Error (Error(..)) import Control.Monad.Trans.Reader import Data.Aeson -import Data.ByteString (ByteString (..), fromStrict, toStrict) -import Data.Either (fromRight) +import Data.ByteString (fromStrict) import Data.List (find, elemIndex) import Data.Maybe (fromMaybe, fromJust, isJust, isNothing) import Data.String (IsString (..)) import Data.Text hiding (elem, find, head, length, map, null, splitAt, tail, words) import qualified Data.Text as T -import Data.Text.Encoding (decodeUtf8, encodeUtf8) import Data.Text.Encoding.Base64 import Data.Time.Clock (NominalDiffTime (..), nominalDay, UTCTime(..), getCurrentTime, addUTCTime) -import Data.UUID.V4 -import qualified Data.ByteString.Char8 as BS import qualified Data.Map.Strict as Map import GHC.Read (readPrec, lexP) -import Jose.Jwa -import Jose.Jwe -import Jose.Jwk (generateRsaKeyPair, generateSymmetricKey, KeyUse(Enc), KeyId) +import Jose.Jwk (generateRsaKeyPair, KeyUse(Enc), KeyId) import Jose.Jwt hiding (decode, encode) import Network.HTTP.Client (newManager, defaultManagerSettings) @@ -191,10 +185,9 @@ codeServer = handleCreds ---------------------- newtype ACode = ACode String deriving (Show) -newtype RToken = RToken String deriving (Show) data ClientData = ClientData --TODO support other flows - { authID :: Either ACode RToken + { authID :: Either ACode JWTWrapper , clientID :: Maybe String , clientSecret :: Maybe String , redirect :: Maybe String @@ -207,7 +200,8 @@ instance FromHttpApiData AuthFlow where instance FromForm ClientData where fromForm f = ClientData - <$> ((parseUnique @AuthFlow "grant_type" f) *> ((Left . ACode <$> parseUnique "code" f) <|> (Right . RToken <$> parseUnique "refresh_token" f))) + <$> ((parseUnique @AuthFlow "grant_type" f) *> ((Left . ACode <$> parseUnique "code" f) + <|> (Right <$> parseUnique "refresh_token" f))) <*> parseMaybe "client_id" f <*> parseMaybe "client_secret" f <*> parseMaybe "redirect_uri" f @@ -216,30 +210,6 @@ instance Error Text where strMsg = pack -data JWTWrapper = JWTW - { acessToken :: String - , expiresIn :: NominalDiffTime - , refreshToken :: Maybe String - } deriving (Show) - -instance ToJSON JWTWrapper where - toJSON (JWTW a e r) = object - [ "access_token" .= a - , "token_type" .= ("JWT" :: Text) - , "expires_in" .= fromEnum e - , "refresh_token" .= r ] - -instance FromJSON JWTWrapper where - parseJSON (Object o) = JWTW - <$> o .: "access_token" - <*> o .: "expires_in" - <*> o .:? "refresh_token" - -instance FromHttpApiData JWTWrapper where - parseHeader bs = case decode (fromStrict bs) of - Just x -> Right x - Nothing -> Left "Invalid JWT wrapper" - type Token = "token" :> ReqBody '[FormUrlEncoded] ClientData :> Post '[JSON] JWTWrapper @@ -259,32 +229,14 @@ tokenEndpoint = provideToken unless (isJust mUser) . throwError $ err500 { errBody = "Invalid authorisation code" } -- return JWT {token = "", tokenType = "JWT", expiration = 0.25 * nominalDay} let (user, scopes) = fromJust mUser - token <- asks (mkToken @user @userData user scopes) >>= liftIO + token <- asks (mkToken @user user scopes) >>= liftIO liftIO . putStrLn $ "token: " ++ show token - return token - Right (RToken rToken) -> undefined - - -mkToken :: forall user userData . UserData user userData - => user -> [Scope user] -> AuthState user -> IO JWTWrapper -mkToken u scopes state = do - pubKey <- atomically $ readTVar state >>= return . publicKey - now <- getCurrentTime - uuid <- nextRandom - let - lifetimeAT = 120 :: NominalDiffTime -- TODO make configurable - lifetimeRT = nominalDay -- TODO make configurable - at = JWT "Oauth2MockServer" (lifetimeAT `addUTCTime` now) uuid - rt = JWT "Oauth2MockServer" (lifetimeRT `addUTCTime` now) uuid - encodedAT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode at) - encodedRT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode rt) - case encodedAT >> encodedRT of - Right _ -> do - let Jwt aToken = fromRight undefined encodedAT - Jwt rToken = fromRight undefined encodedRT - atomically . modifyTVar state $ \s -> s { activeTokens = Map.insert uuid (u, scopes) (activeTokens s) } - return $ JWTW (BS.unpack aToken) lifetimeAT (Just $ BS.unpack rToken) - Left e -> error $ show e + return token + Right jwtw -> do + mToken <- asks (renewToken @user jwtw) >>= liftIO + case mToken of + Just token -> liftIO (putStrLn $ "refreshed token: " ++ show token) >> return token + Nothing -> throwError $ err500 { errBody = "Invalid refresh token" } ---------------------- @@ -313,7 +265,7 @@ userEndpoint = handleUserData handleUserData jwtw = do let mToken = stripPrefix "Bearer " jwtw unless (isJust mToken) . throwError $ err500 { errBody = "Invalid token format" } - token <- asks (decodeToken @user @userData (fromJust mToken)) >>= liftIO + token <- asks (decodeToken @user (fromJust mToken)) >>= liftIO liftIO $ putStrLn "decoded token:" >> print token case token of Left e -> throwError $ err500 { errBody = fromString $ show e } @@ -327,11 +279,6 @@ userEndpoint = handleUserData Nothing -> throwError $ err500 { errBody = "Unknown token" } -decodeToken :: forall user userData . UserData user userData => Text -> AuthState user -> IO (Either JwtError JwtContent) -decodeToken token state = do - prKey <- atomically $ readTVar state >>= return . privateKey - jwkDecode prKey $ encodeUtf8 token - userListEndpoint :: forall user userData . UserData user userData => AuthServer user (UserList userData) userListEndpoint = handleUserData where