refresh token can be exchanged at token endpoint

This commit is contained in:
David Mosbach 2024-01-21 20:35:40 +01:00
parent d31c9abffc
commit 38e831c9b1
2 changed files with 102 additions and 75 deletions

View File

@ -3,27 +3,36 @@
--
-- SPDX-License-Identifier: AGPL-3.0-or-later
{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables #-}
{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables, TypeApplications, LambdaCase #-}
module AuthCode
( State(..)
, AuthState
, AuthRequest(..)
, JWT(..)
, JWTWrapper(..)
, genUnencryptedCode
, verify
, mkToken
, decodeToken
, renewToken
) where
import User
import Data.Aeson
import Data.ByteString (ByteString (..), fromStrict, toStrict)
import Data.Either (fromRight)
import Data.Map.Strict (Map)
import Data.Maybe (isJust, fromMaybe)
import Data.Maybe (isJust, fromMaybe, fromJust)
import Data.Time.Clock
import Data.Text (pack, replace, Text)
import Data.Text (pack, replace, Text, stripPrefix)
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Text.Encoding.Base64
import Data.UUID
import Data.UUID.V4
import qualified Data.ByteString.Char8 as BS
import qualified Data.Map.Strict as M
import Control.Concurrent (forkIO, threadDelay)
@ -31,7 +40,12 @@ import Control.Concurrent.STM.TVar
import Control.Monad (void, (>=>))
import Control.Monad.STM
import Jose.Jwa
import Jose.Jwe
import Jose.Jwk (Jwk(..))
import Jose.Jwt hiding (decode, encode)
import Servant.API (FromHttpApiData(..))
data JWT = JWT
@ -47,6 +61,31 @@ instance FromJSON JWT where
parseJSON (Object o) = JWT <$> o .: "iss" <*> o .: "exp" <*> o .: "jti"
data JWTWrapper = JWTW
{ acessToken :: String
, expiresIn :: NominalDiffTime
, refreshToken :: Maybe String
} deriving (Show)
instance ToJSON JWTWrapper where
toJSON (JWTW a e r) = object
[ "access_token" .= a
, "token_type" .= ("JWT" :: Text)
, "expires_in" .= fromEnum e
, "refresh_token" .= r ]
instance FromJSON JWTWrapper where
parseJSON (Object o) = JWTW
<$> o .: "access_token"
<*> o .: "expires_in"
<*> o .:? "refresh_token"
instance FromHttpApiData JWTWrapper where
parseHeader bs = case decode (fromStrict bs) of
Just x -> Right x
Nothing -> Left "Invalid JWT wrapper"
data AuthRequest user = AuthRequest
{ client :: String
, codeExpiration :: NominalDiffTime
@ -81,12 +120,11 @@ genUnencryptedCode req url state = do
then (False, s)
else (True, s{ activeCodes = M.insert simpleCode req s.activeCodes })
if success then expire simpleCode req.codeExpiration state >> return (Just simpleCode) else return Nothing
expire :: Text -> NominalDiffTime -> AuthState user -> IO ()
expire code time state = void . forkIO $ do
threadDelay $ fromEnum time
atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes }
where
expire :: Text -> NominalDiffTime -> AuthState user -> IO ()
expire code time state = void . forkIO $ do
threadDelay $ fromEnum time
atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes }
verify :: Text -> Maybe String -> AuthState user -> IO (Maybe (user, [Scope user]))
@ -99,3 +137,45 @@ verify code mClientID state = do
return $ case mData of
Just (AuthRequest clientID' _ u s) -> if (fromMaybe clientID' mClientID) == clientID' then Just (u, s) else Nothing
_ -> Nothing
mkToken :: user -> [Scope user] -> AuthState user -> IO JWTWrapper
mkToken u scopes state = do
pubKey <- atomically $ readTVar state >>= return . publicKey
now <- getCurrentTime
uuid <- nextRandom
let
lifetimeAT = 120 :: NominalDiffTime -- TODO make configurable
lifetimeRT = nominalDay -- TODO make configurable
at = JWT "Oauth2MockServer" (lifetimeAT `addUTCTime` now) uuid
rt = JWT "Oauth2MockServer" (lifetimeRT `addUTCTime` now) uuid
encodedAT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode at)
encodedRT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode rt)
case encodedAT >> encodedRT of
Right _ -> do
let Jwt aToken = fromRight undefined encodedAT
Jwt rToken = fromRight undefined encodedRT
atomically . modifyTVar state $ \s -> s { activeTokens = M.insert uuid (u, scopes) (activeTokens s) }
return $ JWTW (BS.unpack aToken) lifetimeAT (Just $ BS.unpack rToken)
Left e -> error $ show e
decodeToken :: Text -> AuthState user -> IO (Either JwtError JwtContent)
decodeToken token state = do
prKey <- atomically $ readTVar state >>= return . privateKey
jwkDecode prKey $ encodeUtf8 token
renewToken :: JWTWrapper -> AuthState user -> IO (Maybe JWTWrapper)
renewToken (JWTW _ _ rt) state = case rt >>= stripPrefix "Bearer " . pack of
Just t -> decodeToken t state >>= \case
Right (Jwe (header, body)) -> do
let jwt = fromJust . decode @JWT $ fromStrict body
now <- getCurrentTime
if now <= expiration jwt then return Nothing else do
mUser <- atomically . stateTVar state $ \s ->
let (key, tokens) = M.updateLookupWithKey (\_ _ -> Nothing) (jti jwt) s.activeTokens
in (key, s { activeTokens = tokens })
case mUser of
Just (u, scopes) -> Just <$> mkToken u scopes state
Nothing -> return Nothing
Left _ -> return Nothing
Nothing -> return Nothing

View File

@ -26,26 +26,20 @@ import Control.Monad.Trans.Error (Error(..))
import Control.Monad.Trans.Reader
import Data.Aeson
import Data.ByteString (ByteString (..), fromStrict, toStrict)
import Data.Either (fromRight)
import Data.ByteString (fromStrict)
import Data.List (find, elemIndex)
import Data.Maybe (fromMaybe, fromJust, isJust, isNothing)
import Data.String (IsString (..))
import Data.Text hiding (elem, find, head, length, map, null, splitAt, tail, words)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Text.Encoding.Base64
import Data.Time.Clock (NominalDiffTime (..), nominalDay, UTCTime(..), getCurrentTime, addUTCTime)
import Data.UUID.V4
import qualified Data.ByteString.Char8 as BS
import qualified Data.Map.Strict as Map
import GHC.Read (readPrec, lexP)
import Jose.Jwa
import Jose.Jwe
import Jose.Jwk (generateRsaKeyPair, generateSymmetricKey, KeyUse(Enc), KeyId)
import Jose.Jwk (generateRsaKeyPair, KeyUse(Enc), KeyId)
import Jose.Jwt hiding (decode, encode)
import Network.HTTP.Client (newManager, defaultManagerSettings)
@ -191,10 +185,9 @@ codeServer = handleCreds
----------------------
newtype ACode = ACode String deriving (Show)
newtype RToken = RToken String deriving (Show)
data ClientData = ClientData --TODO support other flows
{ authID :: Either ACode RToken
{ authID :: Either ACode JWTWrapper
, clientID :: Maybe String
, clientSecret :: Maybe String
, redirect :: Maybe String
@ -207,7 +200,8 @@ instance FromHttpApiData AuthFlow where
instance FromForm ClientData where
fromForm f = ClientData
<$> ((parseUnique @AuthFlow "grant_type" f) *> ((Left . ACode <$> parseUnique "code" f) <|> (Right . RToken <$> parseUnique "refresh_token" f)))
<$> ((parseUnique @AuthFlow "grant_type" f) *> ((Left . ACode <$> parseUnique "code" f)
<|> (Right <$> parseUnique "refresh_token" f)))
<*> parseMaybe "client_id" f
<*> parseMaybe "client_secret" f
<*> parseMaybe "redirect_uri" f
@ -216,30 +210,6 @@ instance Error Text where
strMsg = pack
data JWTWrapper = JWTW
{ acessToken :: String
, expiresIn :: NominalDiffTime
, refreshToken :: Maybe String
} deriving (Show)
instance ToJSON JWTWrapper where
toJSON (JWTW a e r) = object
[ "access_token" .= a
, "token_type" .= ("JWT" :: Text)
, "expires_in" .= fromEnum e
, "refresh_token" .= r ]
instance FromJSON JWTWrapper where
parseJSON (Object o) = JWTW
<$> o .: "access_token"
<*> o .: "expires_in"
<*> o .:? "refresh_token"
instance FromHttpApiData JWTWrapper where
parseHeader bs = case decode (fromStrict bs) of
Just x -> Right x
Nothing -> Left "Invalid JWT wrapper"
type Token = "token"
:> ReqBody '[FormUrlEncoded] ClientData
:> Post '[JSON] JWTWrapper
@ -259,32 +229,14 @@ tokenEndpoint = provideToken
unless (isJust mUser) . throwError $ err500 { errBody = "Invalid authorisation code" }
-- return JWT {token = "", tokenType = "JWT", expiration = 0.25 * nominalDay}
let (user, scopes) = fromJust mUser
token <- asks (mkToken @user @userData user scopes) >>= liftIO
token <- asks (mkToken @user user scopes) >>= liftIO
liftIO . putStrLn $ "token: " ++ show token
return token
Right (RToken rToken) -> undefined
mkToken :: forall user userData . UserData user userData
=> user -> [Scope user] -> AuthState user -> IO JWTWrapper
mkToken u scopes state = do
pubKey <- atomically $ readTVar state >>= return . publicKey
now <- getCurrentTime
uuid <- nextRandom
let
lifetimeAT = 120 :: NominalDiffTime -- TODO make configurable
lifetimeRT = nominalDay -- TODO make configurable
at = JWT "Oauth2MockServer" (lifetimeAT `addUTCTime` now) uuid
rt = JWT "Oauth2MockServer" (lifetimeRT `addUTCTime` now) uuid
encodedAT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode at)
encodedRT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode rt)
case encodedAT >> encodedRT of
Right _ -> do
let Jwt aToken = fromRight undefined encodedAT
Jwt rToken = fromRight undefined encodedRT
atomically . modifyTVar state $ \s -> s { activeTokens = Map.insert uuid (u, scopes) (activeTokens s) }
return $ JWTW (BS.unpack aToken) lifetimeAT (Just $ BS.unpack rToken)
Left e -> error $ show e
return token
Right jwtw -> do
mToken <- asks (renewToken @user jwtw) >>= liftIO
case mToken of
Just token -> liftIO (putStrLn $ "refreshed token: " ++ show token) >> return token
Nothing -> throwError $ err500 { errBody = "Invalid refresh token" }
----------------------
@ -313,7 +265,7 @@ userEndpoint = handleUserData
handleUserData jwtw = do
let mToken = stripPrefix "Bearer " jwtw
unless (isJust mToken) . throwError $ err500 { errBody = "Invalid token format" }
token <- asks (decodeToken @user @userData (fromJust mToken)) >>= liftIO
token <- asks (decodeToken @user (fromJust mToken)) >>= liftIO
liftIO $ putStrLn "decoded token:" >> print token
case token of
Left e -> throwError $ err500 { errBody = fromString $ show e }
@ -327,11 +279,6 @@ userEndpoint = handleUserData
Nothing -> throwError $ err500 { errBody = "Unknown token" }
decodeToken :: forall user userData . UserData user userData => Text -> AuthState user -> IO (Either JwtError JwtContent)
decodeToken token state = do
prKey <- atomically $ readTVar state >>= return . privateKey
jwkDecode prKey $ encodeUtf8 token
userListEndpoint :: forall user userData . UserData user userData => AuthServer user (UserList userData)
userListEndpoint = handleUserData
where