store active tokens & scopes
This commit is contained in:
parent
c7989034b4
commit
2716ccf2c0
@ -20,6 +20,7 @@ extra-source-files:
|
||||
library
|
||||
exposed-modules:
|
||||
AuthCode
|
||||
DB
|
||||
Server
|
||||
User
|
||||
other-modules:
|
||||
@ -45,6 +46,7 @@ library
|
||||
, text
|
||||
, time
|
||||
, transformers
|
||||
, uuid
|
||||
, warp
|
||||
default-language: Haskell2010
|
||||
|
||||
@ -74,6 +76,7 @@ executable oauth2-mock-server-exe
|
||||
, text
|
||||
, time
|
||||
, transformers
|
||||
, uuid
|
||||
, warp
|
||||
default-language: Haskell2010
|
||||
|
||||
@ -104,5 +107,6 @@ test-suite oauth2-mock-server-test
|
||||
, text
|
||||
, time
|
||||
, transformers
|
||||
, uuid
|
||||
, warp
|
||||
default-language: Haskell2010
|
||||
|
||||
@ -36,6 +36,7 @@ dependencies:
|
||||
- jose-jwt
|
||||
- base64
|
||||
- http-api-data
|
||||
- uuid
|
||||
|
||||
ghc-options:
|
||||
- -Wall
|
||||
|
||||
@ -1,17 +1,23 @@
|
||||
{-# LANGUAGE OverloadedRecordDot, OverloadedStrings #-}
|
||||
{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables #-}
|
||||
|
||||
module AuthCode
|
||||
( State (..)
|
||||
( State(..)
|
||||
, AuthState
|
||||
, AuthRequest(..)
|
||||
, JWT(..)
|
||||
, genUnencryptedCode
|
||||
, verify
|
||||
) where
|
||||
|
||||
import User
|
||||
|
||||
import Data.Aeson
|
||||
import Data.Map.Strict (Map)
|
||||
import Data.Maybe (isJust, fromMaybe)
|
||||
import Data.Time.Clock
|
||||
import Data.Text (pack, replace, Text)
|
||||
import Data.Text.Encoding.Base64
|
||||
import Data.UUID
|
||||
|
||||
import qualified Data.Map.Strict as M
|
||||
|
||||
@ -23,42 +29,62 @@ import Control.Monad.STM
|
||||
import Jose.Jwk (Jwk(..))
|
||||
|
||||
|
||||
data JWT = JWT
|
||||
{ issuer :: Text
|
||||
, expiration :: UTCTime
|
||||
, jti :: UUID
|
||||
} deriving (Show, Eq)
|
||||
|
||||
instance ToJSON JWT where
|
||||
toJSON (JWT i e j) = object ["iss" .= i, "exp" .= e, "jti" .= j]
|
||||
|
||||
instance FromJSON JWT where
|
||||
parseJSON (Object o) = JWT <$> o .: "iss" <*> o .: "exp" <*> o .: "jti"
|
||||
|
||||
|
||||
data State = State
|
||||
{ activeCodes :: Map Text (String, UTCTime) -- ^ maps auth codes to (client ID, expiration time)
|
||||
data AuthRequest user = AuthRequest
|
||||
{ client :: String
|
||||
, codeExpiration :: NominalDiffTime
|
||||
, user :: user
|
||||
, scopes :: [Scope user]
|
||||
}
|
||||
|
||||
|
||||
|
||||
data State user = State
|
||||
{ activeCodes :: Map Text (AuthRequest user)
|
||||
, activeTokens :: Map UUID (user, [Scope user])
|
||||
, publicKey :: Jwk
|
||||
, privateKey :: Jwk
|
||||
} deriving Show
|
||||
}
|
||||
|
||||
type AuthState = TVar State
|
||||
type AuthState user = TVar (State user)
|
||||
|
||||
genUnencryptedCode :: String
|
||||
genUnencryptedCode :: AuthRequest user
|
||||
-> String
|
||||
-> NominalDiffTime
|
||||
-> AuthState
|
||||
-> AuthState user
|
||||
-> IO (Maybe Text)
|
||||
genUnencryptedCode client url expiration state = do
|
||||
genUnencryptedCode req url state = do
|
||||
now <- getCurrentTime
|
||||
let
|
||||
expiresAt = expiration `addUTCTime` now
|
||||
simpleCode = replace "/" "%2F" . replace "=" "%3D" . encodeBase64 . pack . filter (/= ' ') $ client <> url <> show now <> show expiresAt
|
||||
expiresAt = req.codeExpiration `addUTCTime` now
|
||||
simpleCode = replace "/" "%2F" . replace "=" "%3D" . encodeBase64 . pack . filter (/= ' ') $ req.client <> url <> show now <> show expiresAt
|
||||
success <- atomically . stateTVar state $ \s ->
|
||||
let mEntry = M.lookup simpleCode s.activeCodes
|
||||
in
|
||||
if isJust mEntry
|
||||
then (False, s)
|
||||
else (True, s{ activeCodes = M.insert simpleCode (client, expiresAt) s.activeCodes })
|
||||
if success then expire simpleCode expiration state >> return (Just simpleCode) else return Nothing
|
||||
else (True, s{ activeCodes = M.insert simpleCode req s.activeCodes })
|
||||
if success then expire simpleCode req.codeExpiration state >> return (Just simpleCode) else return Nothing
|
||||
|
||||
|
||||
expire :: Text -> NominalDiffTime -> AuthState -> IO ()
|
||||
expire :: Text -> NominalDiffTime -> AuthState user -> IO ()
|
||||
expire code time state = void . forkIO $ do
|
||||
threadDelay $ fromEnum time
|
||||
atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes }
|
||||
|
||||
|
||||
verify :: Text -> Maybe String -> AuthState -> IO Bool
|
||||
verify :: Text -> Maybe String -> AuthState user -> IO (Maybe (user, [Scope user]))
|
||||
verify code mClientID state = do
|
||||
now <- getCurrentTime
|
||||
mData <- atomically $ do
|
||||
@ -66,5 +92,5 @@ verify code mClientID state = do
|
||||
modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes }
|
||||
return result
|
||||
return $ case mData of
|
||||
Just (clientID', _) -> (fromMaybe clientID' mClientID) == clientID'
|
||||
_ -> False
|
||||
Just (AuthRequest clientID' _ u s) -> if (fromMaybe clientID' mClientID) == clientID' then Just (u, s) else Nothing
|
||||
_ -> Nothing
|
||||
|
||||
@ -12,7 +12,7 @@ import User
|
||||
import Control.Applicative ((<|>))
|
||||
import Control.Concurrent
|
||||
import Control.Concurrent.STM (atomically)
|
||||
import Control.Concurrent.STM.TVar (newTVarIO, readTVar)
|
||||
import Control.Concurrent.STM.TVar (newTVarIO, readTVar, modifyTVar)
|
||||
import Control.Exception (bracket)
|
||||
import Control.Monad (unless)
|
||||
import Control.Monad.IO.Class
|
||||
@ -26,6 +26,7 @@ import Data.String (IsString (..))
|
||||
import Data.Text hiding (elem, find, head, length, map, null, splitAt, tail, words)
|
||||
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
|
||||
import Data.Time.Clock (NominalDiffTime (..), nominalDay, UTCTime(..), getCurrentTime, addUTCTime)
|
||||
import Data.UUID.V4
|
||||
|
||||
import qualified Data.ByteString.Char8 as BS
|
||||
import qualified Data.Map.Strict as Map
|
||||
@ -103,13 +104,13 @@ type Auth user userData = BasicAuth "login" user
|
||||
-- type Insert = "insert" :> Post '[JSON] User
|
||||
|
||||
|
||||
type AuthHandler = ReaderT AuthState Handler
|
||||
type AuthServer a = ServerT a AuthHandler
|
||||
type AuthHandler user = ReaderT (AuthState user) Handler
|
||||
type AuthServer user a = ServerT a (AuthHandler user)
|
||||
|
||||
toHandler :: AuthState -> AuthHandler a -> Handler a
|
||||
toHandler :: forall user userData a . UserData user userData => AuthState user -> AuthHandler user a -> Handler a
|
||||
toHandler s h = runReaderT h s
|
||||
|
||||
authServer :: forall user userData . UserData user userData => AuthServer (Auth user userData)
|
||||
authServer :: forall user userData . UserData user userData => AuthServer user (Auth user userData)
|
||||
authServer = handleAuth
|
||||
where
|
||||
handleAuth :: user
|
||||
@ -118,18 +119,20 @@ authServer = handleAuth
|
||||
-> QResType
|
||||
-> QRedirect
|
||||
-> Maybe QState
|
||||
-> AuthHandler userData
|
||||
-> AuthHandler user userData
|
||||
handleAuth u scopes client responseType url mState = do
|
||||
unless (isJust $ find (\c -> ident c == pack client) trustedClients) . -- TODO fetch trusted clients from db | TODO also check if the redirect url really belongs to the client
|
||||
throwError $ err404 { errBody = "Not a trusted client."}
|
||||
let responseType' = read @ResponseType responseType
|
||||
let
|
||||
responseType' = read @ResponseType responseType
|
||||
scopes' = map (readScope @user @userData) $ words scopes
|
||||
liftIO $ print responseType'
|
||||
unless (responseType' == Code) $ throwError err500 { errBody = "Unsupported response type" }
|
||||
mAuthCode <- asks (genUnencryptedCode client url 600) >>= liftIO
|
||||
mAuthCode <- asks (genUnencryptedCode (AuthRequest client 600 u scopes') url) >>= liftIO
|
||||
liftIO $ print mAuthCode
|
||||
-- liftIO . putStrLn $ "user: " ++ show u ++ " | scopes: " ++ show (map (showScope @user @userData) scopes')
|
||||
liftIO . putStrLn $ "user: " ++ show u ++ " | scopes: " ++ show (map (showScope @user @userData) scopes')
|
||||
redirect $ addParams url mAuthCode mState
|
||||
redirect :: Maybe ByteString -> AuthHandler userData
|
||||
redirect :: Maybe ByteString -> AuthHandler user userData
|
||||
redirect (Just url) = liftIO (print url) >> throwError err303 { errHeaders = [("Location", url)]}
|
||||
redirect Nothing = throwError err500 { errBody = "Could not generate authorisation code."}
|
||||
addParams :: String -> Maybe Text -> Maybe Text -> Maybe ByteString
|
||||
@ -167,18 +170,6 @@ instance FromForm ClientData where
|
||||
<*> parseMaybe "redirect_uri" f
|
||||
|
||||
|
||||
|
||||
data JWT = JWT
|
||||
{ issuer :: Text
|
||||
, expiration :: UTCTime
|
||||
} deriving (Show, Eq)
|
||||
|
||||
instance ToJSON JWT where
|
||||
toJSON (JWT i e) = object ["iss" .= i, "exp" .= e]
|
||||
|
||||
instance FromJSON JWT where
|
||||
parseJSON (Object o) = JWT <$> o .: "iss" <*> o .: "exp"
|
||||
|
||||
data JWTWrapper = JWTW
|
||||
{ token :: String
|
||||
, expiresIn :: NominalDiffTime
|
||||
@ -202,32 +193,37 @@ type Token = "token"
|
||||
:> Post '[JSON] JWTWrapper
|
||||
|
||||
|
||||
tokenEndpoint :: AuthServer Token
|
||||
tokenEndpoint :: forall user userData . UserData user userData => AuthServer user Token
|
||||
tokenEndpoint = provideToken
|
||||
where
|
||||
provideToken :: ClientData -> AuthHandler JWTWrapper
|
||||
provideToken :: ClientData -> AuthHandler user JWTWrapper
|
||||
provideToken client = do
|
||||
unless (isNothing (clientID client >> clientSecret client)
|
||||
|| Client (pack . fromJust $ clientID client) (pack . fromJust $ clientSecret client) `elem` trustedClients) .
|
||||
throwError $ err500 { errBody = "Invalid client" }
|
||||
valid <- asks (verify (pack $ authCode client) (clientID client)) >>= liftIO -- TODO verify redirect url here
|
||||
unless valid . throwError $ err500 { errBody = "Invalid authorisation code" }
|
||||
mUser <- asks (verify (pack $ authCode client) (clientID client)) >>= liftIO -- TODO verify redirect url here
|
||||
unless (isJust mUser) . throwError $ err500 { errBody = "Invalid authorisation code" }
|
||||
-- return JWT {token = "", tokenType = "JWT", expiration = 0.25 * nominalDay}
|
||||
token <- asks mkToken >>= liftIO
|
||||
let (user, scopes) = fromJust mUser
|
||||
token <- asks (mkToken @user @userData user scopes) >>= liftIO
|
||||
liftIO . putStrLn $ "token: " ++ show token
|
||||
return token
|
||||
|
||||
|
||||
mkToken :: AuthState -> IO JWTWrapper
|
||||
mkToken state = do
|
||||
mkToken :: forall user userData . UserData user userData
|
||||
=> user -> [Scope user] -> AuthState user -> IO JWTWrapper
|
||||
mkToken u scopes state = do
|
||||
pubKey <- atomically $ readTVar state >>= return . publicKey
|
||||
now <- getCurrentTime
|
||||
uuid <- nextRandom
|
||||
let
|
||||
lifetime = nominalDay / 4 -- TODO make configurable
|
||||
jwt = JWT "Oauth2MockServer" (lifetime `addUTCTime` now)
|
||||
lifetime = nominalDay / 24 -- TODO make configurable
|
||||
jwt = JWT "Oauth2MockServer" (lifetime `addUTCTime` now) uuid
|
||||
encoded <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode jwt)
|
||||
case encoded of
|
||||
Right (Jwt token) -> return $ JWTW (BS.unpack token) lifetime
|
||||
Right (Jwt token) -> do
|
||||
atomically . modifyTVar state $ \s -> s { activeTokens = Map.insert uuid (u, scopes) (activeTokens s) }
|
||||
return $ JWTW (BS.unpack token) lifetime
|
||||
Left e -> error $ show e
|
||||
|
||||
|
||||
@ -250,20 +246,21 @@ type UserList userData = Users
|
||||
:> Get '[JSON] [userData] -- TODO support query params
|
||||
|
||||
|
||||
userEndpoint :: forall user userData . UserData user userData => AuthServer (Me userData)
|
||||
userEndpoint :: forall user userData . UserData user userData => AuthServer user (Me userData)
|
||||
userEndpoint = handleUserData
|
||||
where
|
||||
handleUserData :: Text -> AuthHandler userData
|
||||
handleUserData :: Text -> AuthHandler user userData
|
||||
handleUserData jwtw = do
|
||||
let mToken = stripPrefix "Bearer " jwtw
|
||||
unless (isJust mToken) . throwError $ err500 { errBody = "Invalid token format"}
|
||||
token <- asks (decodeToken (fromJust mToken)) >>= liftIO
|
||||
token <- asks (decodeToken @user @userData (fromJust mToken)) >>= liftIO
|
||||
liftIO $ putStrLn "decoded token:" >> print token
|
||||
case token of
|
||||
Left e -> throwError $ err500 { errBody = fromString $ show e }
|
||||
Right (Jwe (header, body)) -> do
|
||||
let jwt = fromJust . decode @JWT $ fromStrict body
|
||||
-- TODO check if token grants access, then read logged in user from cookie
|
||||
liftIO $ print jwt
|
||||
return mempty
|
||||
-- let
|
||||
-- scopes' = map (readScope @user @userData) $ words scopes
|
||||
@ -271,15 +268,15 @@ userEndpoint = handleUserData
|
||||
-- liftIO . putStrLn $ "user: " ++ show u ++ " | scopes: " ++ show (map (showScope @user @userData) scopes')
|
||||
-- return uData
|
||||
|
||||
decodeToken :: Text -> AuthState -> IO (Either JwtError JwtContent)
|
||||
decodeToken :: forall user userData . UserData user userData => Text -> AuthState user -> IO (Either JwtError JwtContent)
|
||||
decodeToken token state = do
|
||||
prKey <- atomically $ readTVar state >>= return . privateKey
|
||||
jwkDecode prKey $ encodeUtf8 token
|
||||
|
||||
userListEndpoint :: forall user userData . UserData user userData => AuthServer (UserList userData)
|
||||
userListEndpoint :: forall user userData . UserData user userData => AuthServer user (UserList userData)
|
||||
userListEndpoint = handleUserData
|
||||
where
|
||||
handleUserData :: Text -> AuthHandler [userData]
|
||||
handleUserData :: Text -> AuthHandler user [userData]
|
||||
handleUserData jwtw = do
|
||||
undefined
|
||||
|
||||
@ -293,13 +290,13 @@ type Routing user userData = Auth user userData
|
||||
:<|> Me userData
|
||||
:<|> UserList userData
|
||||
|
||||
routing :: forall user userData . UserData user userData => AuthServer (Routing user userData)
|
||||
routing :: forall user userData . UserData user userData => AuthServer user (Routing user userData)
|
||||
routing = authServer @user @userData
|
||||
:<|> tokenEndpoint
|
||||
:<|> tokenEndpoint @user @userData
|
||||
:<|> userEndpoint @user @userData
|
||||
:<|> userListEndpoint @user @userData
|
||||
|
||||
exampleAuthServer :: AuthServer (Routing User (Map.Map Text Text))
|
||||
exampleAuthServer :: AuthServer User (Routing User (Map.Map Text Text))
|
||||
exampleAuthServer = routing
|
||||
|
||||
|
||||
@ -309,8 +306,8 @@ authAPI = Proxy
|
||||
-- insecureOAuthMock :: Application
|
||||
-- insecureOAuthMock = authAPI `serve` exampleAuthServer
|
||||
|
||||
insecureOAuthMock' :: [User] -> AuthState -> Application
|
||||
insecureOAuthMock' testUsers s = serveWithContext authAPI c $ hoistServerWithContext authAPI p (toHandler s) exampleAuthServer
|
||||
insecureOAuthMock' :: [User] -> AuthState User -> Application
|
||||
insecureOAuthMock' testUsers s = serveWithContext authAPI c $ hoistServerWithContext authAPI p (toHandler @User @(Map.Map Text Text) s) exampleAuthServer
|
||||
where
|
||||
c = authenticate testUsers :. EmptyContext
|
||||
p = Proxy :: Proxy '[BasicAuthCheck User]
|
||||
@ -328,7 +325,7 @@ authenticate users = BasicAuthCheck $ \authData -> do
|
||||
|
||||
runMockServer :: Int -> IO ()
|
||||
runMockServer port = do
|
||||
state <- mkState
|
||||
state <- mkState @User @(Map.Map Text Text)
|
||||
run port $ insecureOAuthMock' testUsers state
|
||||
|
||||
-- runMockServer' :: Int -> IO ()
|
||||
@ -339,9 +336,11 @@ runMockServer port = do
|
||||
-- runClientM (frontend $ BasicAuthData "foo@bar.com" "0000") (mkClientEnv mgr (BaseUrl Http "localhost" port ""))
|
||||
-- >>= print
|
||||
|
||||
mkState :: IO AuthState
|
||||
mkState :: forall user userData . UserData user userData => IO (AuthState user)
|
||||
mkState = do
|
||||
(publicKey, privateKey) <- generateRsaKeyPair 256 (KeyId "Oauth2MockKey") Enc Nothing
|
||||
let activeCodes = Map.empty
|
||||
let
|
||||
activeCodes = Map.empty
|
||||
activeTokens = Map.empty
|
||||
newTVarIO State{..}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user