refresh token can be exchanged at token endpoint
This commit is contained in:
parent
d31c9abffc
commit
38e831c9b1
@ -3,27 +3,36 @@
|
||||
--
|
||||
-- SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
|
||||
{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables #-}
|
||||
{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables, TypeApplications, LambdaCase #-}
|
||||
|
||||
module AuthCode
|
||||
( State(..)
|
||||
, AuthState
|
||||
, AuthRequest(..)
|
||||
, JWT(..)
|
||||
, JWTWrapper(..)
|
||||
, genUnencryptedCode
|
||||
, verify
|
||||
, mkToken
|
||||
, decodeToken
|
||||
, renewToken
|
||||
) where
|
||||
|
||||
import User
|
||||
|
||||
import Data.Aeson
|
||||
import Data.ByteString (ByteString (..), fromStrict, toStrict)
|
||||
import Data.Either (fromRight)
|
||||
import Data.Map.Strict (Map)
|
||||
import Data.Maybe (isJust, fromMaybe)
|
||||
import Data.Maybe (isJust, fromMaybe, fromJust)
|
||||
import Data.Time.Clock
|
||||
import Data.Text (pack, replace, Text)
|
||||
import Data.Text (pack, replace, Text, stripPrefix)
|
||||
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
|
||||
import Data.Text.Encoding.Base64
|
||||
import Data.UUID
|
||||
import Data.UUID.V4
|
||||
|
||||
import qualified Data.ByteString.Char8 as BS
|
||||
import qualified Data.Map.Strict as M
|
||||
|
||||
import Control.Concurrent (forkIO, threadDelay)
|
||||
@ -31,7 +40,12 @@ import Control.Concurrent.STM.TVar
|
||||
import Control.Monad (void, (>=>))
|
||||
import Control.Monad.STM
|
||||
|
||||
import Jose.Jwa
|
||||
import Jose.Jwe
|
||||
import Jose.Jwk (Jwk(..))
|
||||
import Jose.Jwt hiding (decode, encode)
|
||||
|
||||
import Servant.API (FromHttpApiData(..))
|
||||
|
||||
|
||||
data JWT = JWT
|
||||
@ -47,6 +61,31 @@ instance FromJSON JWT where
|
||||
parseJSON (Object o) = JWT <$> o .: "iss" <*> o .: "exp" <*> o .: "jti"
|
||||
|
||||
|
||||
data JWTWrapper = JWTW
|
||||
{ acessToken :: String
|
||||
, expiresIn :: NominalDiffTime
|
||||
, refreshToken :: Maybe String
|
||||
} deriving (Show)
|
||||
|
||||
instance ToJSON JWTWrapper where
|
||||
toJSON (JWTW a e r) = object
|
||||
[ "access_token" .= a
|
||||
, "token_type" .= ("JWT" :: Text)
|
||||
, "expires_in" .= fromEnum e
|
||||
, "refresh_token" .= r ]
|
||||
|
||||
instance FromJSON JWTWrapper where
|
||||
parseJSON (Object o) = JWTW
|
||||
<$> o .: "access_token"
|
||||
<*> o .: "expires_in"
|
||||
<*> o .:? "refresh_token"
|
||||
|
||||
instance FromHttpApiData JWTWrapper where
|
||||
parseHeader bs = case decode (fromStrict bs) of
|
||||
Just x -> Right x
|
||||
Nothing -> Left "Invalid JWT wrapper"
|
||||
|
||||
|
||||
data AuthRequest user = AuthRequest
|
||||
{ client :: String
|
||||
, codeExpiration :: NominalDiffTime
|
||||
@ -81,12 +120,11 @@ genUnencryptedCode req url state = do
|
||||
then (False, s)
|
||||
else (True, s{ activeCodes = M.insert simpleCode req s.activeCodes })
|
||||
if success then expire simpleCode req.codeExpiration state >> return (Just simpleCode) else return Nothing
|
||||
|
||||
|
||||
expire :: Text -> NominalDiffTime -> AuthState user -> IO ()
|
||||
expire code time state = void . forkIO $ do
|
||||
threadDelay $ fromEnum time
|
||||
atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes }
|
||||
where
|
||||
expire :: Text -> NominalDiffTime -> AuthState user -> IO ()
|
||||
expire code time state = void . forkIO $ do
|
||||
threadDelay $ fromEnum time
|
||||
atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes }
|
||||
|
||||
|
||||
verify :: Text -> Maybe String -> AuthState user -> IO (Maybe (user, [Scope user]))
|
||||
@ -99,3 +137,45 @@ verify code mClientID state = do
|
||||
return $ case mData of
|
||||
Just (AuthRequest clientID' _ u s) -> if (fromMaybe clientID' mClientID) == clientID' then Just (u, s) else Nothing
|
||||
_ -> Nothing
|
||||
|
||||
|
||||
mkToken :: user -> [Scope user] -> AuthState user -> IO JWTWrapper
|
||||
mkToken u scopes state = do
|
||||
pubKey <- atomically $ readTVar state >>= return . publicKey
|
||||
now <- getCurrentTime
|
||||
uuid <- nextRandom
|
||||
let
|
||||
lifetimeAT = 120 :: NominalDiffTime -- TODO make configurable
|
||||
lifetimeRT = nominalDay -- TODO make configurable
|
||||
at = JWT "Oauth2MockServer" (lifetimeAT `addUTCTime` now) uuid
|
||||
rt = JWT "Oauth2MockServer" (lifetimeRT `addUTCTime` now) uuid
|
||||
encodedAT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode at)
|
||||
encodedRT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode rt)
|
||||
case encodedAT >> encodedRT of
|
||||
Right _ -> do
|
||||
let Jwt aToken = fromRight undefined encodedAT
|
||||
Jwt rToken = fromRight undefined encodedRT
|
||||
atomically . modifyTVar state $ \s -> s { activeTokens = M.insert uuid (u, scopes) (activeTokens s) }
|
||||
return $ JWTW (BS.unpack aToken) lifetimeAT (Just $ BS.unpack rToken)
|
||||
Left e -> error $ show e
|
||||
|
||||
decodeToken :: Text -> AuthState user -> IO (Either JwtError JwtContent)
|
||||
decodeToken token state = do
|
||||
prKey <- atomically $ readTVar state >>= return . privateKey
|
||||
jwkDecode prKey $ encodeUtf8 token
|
||||
|
||||
renewToken :: JWTWrapper -> AuthState user -> IO (Maybe JWTWrapper)
|
||||
renewToken (JWTW _ _ rt) state = case rt >>= stripPrefix "Bearer " . pack of
|
||||
Just t -> decodeToken t state >>= \case
|
||||
Right (Jwe (header, body)) -> do
|
||||
let jwt = fromJust . decode @JWT $ fromStrict body
|
||||
now <- getCurrentTime
|
||||
if now <= expiration jwt then return Nothing else do
|
||||
mUser <- atomically . stateTVar state $ \s ->
|
||||
let (key, tokens) = M.updateLookupWithKey (\_ _ -> Nothing) (jti jwt) s.activeTokens
|
||||
in (key, s { activeTokens = tokens })
|
||||
case mUser of
|
||||
Just (u, scopes) -> Just <$> mkToken u scopes state
|
||||
Nothing -> return Nothing
|
||||
Left _ -> return Nothing
|
||||
Nothing -> return Nothing
|
||||
|
||||
@ -26,26 +26,20 @@ import Control.Monad.Trans.Error (Error(..))
|
||||
import Control.Monad.Trans.Reader
|
||||
|
||||
import Data.Aeson
|
||||
import Data.ByteString (ByteString (..), fromStrict, toStrict)
|
||||
import Data.Either (fromRight)
|
||||
import Data.ByteString (fromStrict)
|
||||
import Data.List (find, elemIndex)
|
||||
import Data.Maybe (fromMaybe, fromJust, isJust, isNothing)
|
||||
import Data.String (IsString (..))
|
||||
import Data.Text hiding (elem, find, head, length, map, null, splitAt, tail, words)
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
|
||||
import Data.Text.Encoding.Base64
|
||||
import Data.Time.Clock (NominalDiffTime (..), nominalDay, UTCTime(..), getCurrentTime, addUTCTime)
|
||||
import Data.UUID.V4
|
||||
|
||||
import qualified Data.ByteString.Char8 as BS
|
||||
import qualified Data.Map.Strict as Map
|
||||
|
||||
import GHC.Read (readPrec, lexP)
|
||||
|
||||
import Jose.Jwa
|
||||
import Jose.Jwe
|
||||
import Jose.Jwk (generateRsaKeyPair, generateSymmetricKey, KeyUse(Enc), KeyId)
|
||||
import Jose.Jwk (generateRsaKeyPair, KeyUse(Enc), KeyId)
|
||||
import Jose.Jwt hiding (decode, encode)
|
||||
|
||||
import Network.HTTP.Client (newManager, defaultManagerSettings)
|
||||
@ -191,10 +185,9 @@ codeServer = handleCreds
|
||||
----------------------
|
||||
|
||||
newtype ACode = ACode String deriving (Show)
|
||||
newtype RToken = RToken String deriving (Show)
|
||||
|
||||
data ClientData = ClientData --TODO support other flows
|
||||
{ authID :: Either ACode RToken
|
||||
{ authID :: Either ACode JWTWrapper
|
||||
, clientID :: Maybe String
|
||||
, clientSecret :: Maybe String
|
||||
, redirect :: Maybe String
|
||||
@ -207,7 +200,8 @@ instance FromHttpApiData AuthFlow where
|
||||
|
||||
instance FromForm ClientData where
|
||||
fromForm f = ClientData
|
||||
<$> ((parseUnique @AuthFlow "grant_type" f) *> ((Left . ACode <$> parseUnique "code" f) <|> (Right . RToken <$> parseUnique "refresh_token" f)))
|
||||
<$> ((parseUnique @AuthFlow "grant_type" f) *> ((Left . ACode <$> parseUnique "code" f)
|
||||
<|> (Right <$> parseUnique "refresh_token" f)))
|
||||
<*> parseMaybe "client_id" f
|
||||
<*> parseMaybe "client_secret" f
|
||||
<*> parseMaybe "redirect_uri" f
|
||||
@ -216,30 +210,6 @@ instance Error Text where
|
||||
strMsg = pack
|
||||
|
||||
|
||||
data JWTWrapper = JWTW
|
||||
{ acessToken :: String
|
||||
, expiresIn :: NominalDiffTime
|
||||
, refreshToken :: Maybe String
|
||||
} deriving (Show)
|
||||
|
||||
instance ToJSON JWTWrapper where
|
||||
toJSON (JWTW a e r) = object
|
||||
[ "access_token" .= a
|
||||
, "token_type" .= ("JWT" :: Text)
|
||||
, "expires_in" .= fromEnum e
|
||||
, "refresh_token" .= r ]
|
||||
|
||||
instance FromJSON JWTWrapper where
|
||||
parseJSON (Object o) = JWTW
|
||||
<$> o .: "access_token"
|
||||
<*> o .: "expires_in"
|
||||
<*> o .:? "refresh_token"
|
||||
|
||||
instance FromHttpApiData JWTWrapper where
|
||||
parseHeader bs = case decode (fromStrict bs) of
|
||||
Just x -> Right x
|
||||
Nothing -> Left "Invalid JWT wrapper"
|
||||
|
||||
type Token = "token"
|
||||
:> ReqBody '[FormUrlEncoded] ClientData
|
||||
:> Post '[JSON] JWTWrapper
|
||||
@ -259,32 +229,14 @@ tokenEndpoint = provideToken
|
||||
unless (isJust mUser) . throwError $ err500 { errBody = "Invalid authorisation code" }
|
||||
-- return JWT {token = "", tokenType = "JWT", expiration = 0.25 * nominalDay}
|
||||
let (user, scopes) = fromJust mUser
|
||||
token <- asks (mkToken @user @userData user scopes) >>= liftIO
|
||||
token <- asks (mkToken @user user scopes) >>= liftIO
|
||||
liftIO . putStrLn $ "token: " ++ show token
|
||||
return token
|
||||
Right (RToken rToken) -> undefined
|
||||
|
||||
|
||||
mkToken :: forall user userData . UserData user userData
|
||||
=> user -> [Scope user] -> AuthState user -> IO JWTWrapper
|
||||
mkToken u scopes state = do
|
||||
pubKey <- atomically $ readTVar state >>= return . publicKey
|
||||
now <- getCurrentTime
|
||||
uuid <- nextRandom
|
||||
let
|
||||
lifetimeAT = 120 :: NominalDiffTime -- TODO make configurable
|
||||
lifetimeRT = nominalDay -- TODO make configurable
|
||||
at = JWT "Oauth2MockServer" (lifetimeAT `addUTCTime` now) uuid
|
||||
rt = JWT "Oauth2MockServer" (lifetimeRT `addUTCTime` now) uuid
|
||||
encodedAT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode at)
|
||||
encodedRT <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode rt)
|
||||
case encodedAT >> encodedRT of
|
||||
Right _ -> do
|
||||
let Jwt aToken = fromRight undefined encodedAT
|
||||
Jwt rToken = fromRight undefined encodedRT
|
||||
atomically . modifyTVar state $ \s -> s { activeTokens = Map.insert uuid (u, scopes) (activeTokens s) }
|
||||
return $ JWTW (BS.unpack aToken) lifetimeAT (Just $ BS.unpack rToken)
|
||||
Left e -> error $ show e
|
||||
return token
|
||||
Right jwtw -> do
|
||||
mToken <- asks (renewToken @user jwtw) >>= liftIO
|
||||
case mToken of
|
||||
Just token -> liftIO (putStrLn $ "refreshed token: " ++ show token) >> return token
|
||||
Nothing -> throwError $ err500 { errBody = "Invalid refresh token" }
|
||||
|
||||
|
||||
----------------------
|
||||
@ -313,7 +265,7 @@ userEndpoint = handleUserData
|
||||
handleUserData jwtw = do
|
||||
let mToken = stripPrefix "Bearer " jwtw
|
||||
unless (isJust mToken) . throwError $ err500 { errBody = "Invalid token format" }
|
||||
token <- asks (decodeToken @user @userData (fromJust mToken)) >>= liftIO
|
||||
token <- asks (decodeToken @user (fromJust mToken)) >>= liftIO
|
||||
liftIO $ putStrLn "decoded token:" >> print token
|
||||
case token of
|
||||
Left e -> throwError $ err500 { errBody = fromString $ show e }
|
||||
@ -327,11 +279,6 @@ userEndpoint = handleUserData
|
||||
Nothing -> throwError $ err500 { errBody = "Unknown token" }
|
||||
|
||||
|
||||
decodeToken :: forall user userData . UserData user userData => Text -> AuthState user -> IO (Either JwtError JwtContent)
|
||||
decodeToken token state = do
|
||||
prKey <- atomically $ readTVar state >>= return . privateKey
|
||||
jwkDecode prKey $ encodeUtf8 token
|
||||
|
||||
userListEndpoint :: forall user userData . UserData user userData => AuthServer user (UserList userData)
|
||||
userListEndpoint = handleUserData
|
||||
where
|
||||
|
||||
Loading…
Reference in New Issue
Block a user