372 lines
14 KiB
Haskell
372 lines
14 KiB
Haskell
-- SPDX-FileCopyrightText: 2024 UniWorX Systems
|
|
-- SPDX-FileContributor: David Mosbach <david.mosbach@uniworx.de>
|
|
--
|
|
-- SPDX-License-Identifier: AGPL-3.0-or-later
|
|
|
|
{-# LANGUAGE DataKinds, TypeOperators, OverloadedStrings, ScopedTypeVariables, TypeApplications, RecordWildCards, AllowAmbiguousTypes #-}
|
|
|
|
module Server
|
|
{-( insecureOAuthMock'
|
|
, runMockServer
|
|
-- , runMockServer'
|
|
)-} where
|
|
|
|
import AuthCode
|
|
import LoginForm
|
|
import User
|
|
|
|
import Control.Applicative ((<|>))
|
|
import Control.Concurrent
|
|
import Control.Concurrent.STM (atomically)
|
|
import Control.Concurrent.STM.TVar (newTVarIO, readTVar, modifyTVar)
|
|
import Control.Exception (bracket)
|
|
import Control.Monad (unless, (>=>))
|
|
import Control.Monad.IO.Class
|
|
import Control.Monad.Trans.Error (Error(..))
|
|
import Control.Monad.Trans.Reader
|
|
|
|
import Data.Aeson
|
|
import Data.ByteString (fromStrict)
|
|
import Data.List (find, elemIndex)
|
|
import Data.Maybe (fromMaybe, fromJust, isJust, isNothing)
|
|
import Data.String (IsString (..))
|
|
import Data.Text hiding (elem, find, head, length, map, null, splitAt, tail, words)
|
|
import qualified Data.Text as T
|
|
import Data.Text.Encoding.Base64
|
|
import Data.Time.Clock (NominalDiffTime (..), nominalDay, UTCTime(..), getCurrentTime, addUTCTime)
|
|
|
|
import qualified Data.Map.Strict as Map
|
|
|
|
import GHC.Read (readPrec, lexP)
|
|
|
|
import Jose.Jwk (generateRsaKeyPair, KeyUse(Enc), KeyId)
|
|
import Jose.Jwt hiding (decode, encode)
|
|
|
|
import Network.HTTP.Client (newManager, defaultManagerSettings)
|
|
import Network.Wai.Handler.Warp
|
|
|
|
import Servant
|
|
import Servant.Client
|
|
import Servant.API
|
|
|
|
import Text.ParserCombinators.ReadPrec (look, pfail)
import Text.Read (readMaybe)
import qualified Text.Read.Lex as Lex
|
|
|
|
import Web.Internal.FormUrlEncoded (FromForm(..), parseUnique, parseMaybe, Form(..))
|
|
|
|
|
|
|
|
-- | An OAuth2 client known to the mock server.
data AuthClient = Client
  { ident  :: Text -- ^ client identifier, compared against the requested @client_id@
  , secret :: Text -- ^ shared secret, compared against the supplied @client_secret@
  } deriving (Show, Eq)
|
|
|
|
-- | Clients the mock server accepts at the code and token endpoints.
trustedClients :: [AuthClient] -- TODO move to db
trustedClients =
  [ Client { ident = "42", secret = "shhh" }
  ]
|
|
|
|
-- | OAuth2 @response_type@ values, as parsed by the 'Read' instance.
data ResponseType
  = Code    -- ^ authorisation code grant
  | Token   -- ^ implicit grant via access token
  | IDToken -- ^ implicit grant via access token & ID token
  deriving (Eq, Show)

-- | Parses the wire spelling (@code@, @token@ or @id_token@). The input must
-- consist of exactly one identifier; any trailing token makes the parse fail.
instance Read ResponseType where
  readPrec = do
      Lex.Ident word <- lexP
      Lex.EOF <- lexP
      maybe pfail return (lookup word spellings)
    where
      spellings :: [(String, ResponseType)]
      spellings =
        [ ("code", Code)
        , ("token", Token)
        , ("id_token", IDToken)
        ]
|
|
|
|
------------------------------
|
|
---- Authorisation endpoint ----
|
|
------------------------------
|
|
|
|
-- Parameter types shared by the authorisation and code endpoints.

-- | Space-separated list of requested scopes (split with 'words' by the code endpoint).
type QScope    = String
-- | The requesting client's @client_id@.
type QClient   = String
-- | Raw @response_type@ value.
type QResType  = String
-- | The client's @redirect_uri@.
type QRedirect = Text
-- | Opaque @state@ value, echoed back to the client.
type QState    = Text
-- | Value of the @Authorization@ header sent to the code endpoint.
type QAuth     = Text

-- | A query parameter that must be present and must parse.
type QParam = QueryParam' [Required, Strict]
|
|
|
|
-- type Oauth2Params = QParam "scope" QScope
|
|
-- :> QParam "client_id" QClient
|
|
-- :> QParam "response_type" QResType
|
|
-- :> QParam "redirect_uri" QRedirect
|
|
-- :> QueryParam "state" QState
|
|
|
|
-- type ProtectedAuth user = BasicAuth "login" user :> "auth" :> Auth -- Prompts for username & password
|
|
-- type QuickAuth = "qauth" :> Auth -- Prompts for username only
|
|
-- | Authorisation endpoint: @GET /auth@ with the OAuth2 query parameters.
-- Responds with the login page.
type Auth
  =  "auth"
  :> QParam "scope" QScope
  :> QParam "client_id" QClient
  :> QParam "response_type" QResType
  :> QParam "redirect_uri" QRedirect
  :> QueryParam "state" QState
  :> Get '[HTML] Html -- login
|
|
|
|
-- | Code endpoint: @GET /code@. The OAuth2 parameters arrive as @OA2_*@
-- headers, as forwarded by the login page.
type AuthCode
  =  "code"
  :> HeaderR "Authorization" QAuth
  :> HeaderR "OA2_Scope" QScope
  :> HeaderR "OA2_Client_ID" QClient
  :> HeaderR "OA2_Redirect_URI" QRedirect
  :> Header "OA2_State" QState
  :> Get '[JSON] Text -- returns the redirect URI carrying the auth code
|
|
|
|
|
|
-- | Handler monad of the mock server: Servant's 'Handler' with read access to
-- the shared 'AuthState'.
type AuthHandler user = ReaderT (AuthState user) Handler

-- | A Servant server for API @a@ whose handlers run in 'AuthHandler'.
type AuthServer user a = ServerT a (AuthHandler user)
|
|
|
|
-- | Runs an 'AuthHandler' computation against the given state, yielding a
-- plain Servant 'Handler'. It is used as the natural transformation for
-- 'hoistServer'.
toHandler :: forall user userData a . UserData user userData => AuthState user -> AuthHandler user a -> Handler a
toHandler = flip runReaderT
|
|
|
|
-- | Handler of the @/auth@ endpoint.
--
-- Only the authorisation code grant (@response_type=code@) is accepted.
-- Any other value, including one that does not parse, is rejected with
-- HTTP 500 "Unsupported response type". On success, the login page is
-- rendered with the scope, client id, redirect URI and (optional) state
-- handed over as @OA2_*@ entries.
loginServer :: forall user userData . UserData user userData => AuthServer user Auth
loginServer = handleAuth
  where
    handleAuth :: QScope
               -> QClient
               -> QResType
               -> QRedirect
               -> Maybe QState
               -> AuthHandler user Html
    handleAuth scopes client responseType url mState = do
      -- 'readMaybe' rather than 'read': response_type comes straight from the
      -- query string, and a partial 'read' would crash the handler instead
      -- of producing the error below.
      unless (readMaybe responseType == Just Code) $
        throwError err500 { errBody = "Unsupported response type" }
      return $ loginPage headers'
      where
        headers = Map.fromList @Text @Text
          [ ("OA2_Scope", pack scopes)
          , ("OA2_Client_ID", pack client)
          , ("OA2_Redirect_URI", url)
          ]
        -- The state is optional and only forwarded when present.
        headers' = maybe headers (\s -> Map.insert "OA2_State" s headers) mState
|
|
|
|
-- | Handler of the @/code@ endpoint.
--
-- The handler performs these steps:
--
-- 1. Rejects untrusted clients with HTTP 404.
-- 2. Decodes the base64 @user:password@ credentials from the
--    @Authorization@ header.
-- 3. Looks the user up and generates an authorisation code.
-- 4. Returns the client's redirect URI with @code@ (and @state@, if given)
--    appended to its query string.
codeServer :: forall user userData . UserData user userData => AuthServer user AuthCode
codeServer = handleCreds
  where
    handleCreds :: QAuth
                -> QScope
                -> QClient
                -> QRedirect
                -> Maybe QState
                -> AuthHandler user Text
    handleCreds creds scopes client url mState = do
      -- TODO fetch trusted clients from db | TODO also check if the redirect url really belongs to the client
      unless (pack client `elem` map ident trustedClients) .
        throwError $ err404 { errBody = "Not a trusted client."}
      -- Split at the first ':' only. RFC 7617 allows ':' inside the password,
      -- and a missing ':' must be reported rather than crash a pattern match.
      (userName, password) <- case T.breakOn ":" (decodeBase64Lenient creds) of
        (name, rest) | not (T.null rest) -> return (name, T.drop 1 rest)
        _ -> throwError err500 { errBody = "Malformed credentials." }
      let scopes' = map (readScope @user @userData) $ words scopes
      liftIO $ print userName
      mUser <- liftIO $ lookupUser @user @userData userName (Just password)
      u <- maybe (throwError err500 { errBody = "Unknown user."}) return mUser
      mAuthCode <- asks (genUnencryptedCode (AuthRequest client 600 u scopes') (unpack url)) >>= liftIO
      liftIO $ print mAuthCode
      liftIO . putStrLn $ "user: " ++ show u ++ " | scopes: " ++ show (map (showScope @user @userData) scopes')
      redirect $ addParams url mAuthCode mState

    -- | Returns the final redirect URL, or fails if no code could be generated.
    redirect :: Maybe Text -> AuthHandler user Text
    redirect (Just url) = liftIO (print url) >> return url --throwError err303 { errHeaders = [("Location", url)]}
    redirect Nothing = throwError err500 { errBody = "Could not generate authorisation code."}

    -- | Appends @code@ (and the percent-encoded @state@, if any) to the query
    -- string of the redirect URI, keeping any query parameters it already has.
    addParams :: Text -> Maybe Text -> Maybe Text -> Maybe Text
    addParams _ Nothing _ = Nothing
    addParams url (Just code) mState =
      let (base, query) = T.breakOn "?" url
          -- 'T.breakOn' leaves the '?' at the front of 'query'. Drop exactly
          -- that character and keep everything after it, including any
          -- further '?'.
          existing  = T.drop 1 query
          existing' = if T.null existing then "" else "&" <> existing
          rState    = maybe "" (\s -> "&state=" <> encodeQueryValue s) mState
      in Just $ base <> "?code=" <> code <> existing' <> rState

    -- | Percent-encodes the characters that are significant inside a query
    -- string value. For example, a base64 state keeps its '+', which would
    -- otherwise be decoded back as a space.
    encodeQueryValue :: Text -> Text
    encodeQueryValue = T.concatMap escape
      where
        escape c = fromMaybe (T.singleton c) (lookup c reserved)
        reserved :: [(Char, Text)]
        reserved =
          [ ('%', "%25"), ('&', "%26"), ('+', "%2B"), ('=', "%3D")
          , ('/', "%2F"), ('#', "%23"), ('?', "%3F"), (' ', "%20")
          ]
|
|
|
|
|
|
|
|
----------------------
|
|
---- Token Endpoint ----
|
|
----------------------
|
|
|
|
-- | An authorisation code, as received in the @code@ form field.
newtype ACode = ACode String
  deriving (Show)

-- | A refresh token, as received in the @refresh_token@ form field.
newtype RToken = RToken Text
  deriving (Show)
|
|
|
|
-- | Parameters of a request to the token endpoint, decoded by the 'FromForm'
-- instance.
data ClientData = ClientData --TODO support other flows
  { authID       :: Either ACode RToken -- ^ the @code@, or the @refresh_token@ when no code is given
  , grantType    :: Text                -- ^ the @grant_type@ field
  , clientID     :: Maybe String        -- ^ the optional @client_id@ field
  , clientSecret :: Maybe String        -- ^ the optional @client_secret@ field
  , redirect     :: Maybe String        -- ^ the optional @redirect_uri@ field
  } deriving Show
|
|
|
|
-- | Marker type that parses only from the value @authorization_code@.
data AuthFlow = AuthFlow

-- | Any other value is rejected, with the value itself as the error message.
instance FromHttpApiData AuthFlow where
  parseQueryParam param
    | param == "authorization_code" = Right AuthFlow
    | otherwise                     = Left param
|
|
|
|
|
|
-- | Decodes a token request form.
--
-- The @code@ field is preferred. Failing that, the optional @scope@ field is
-- parsed and the @refresh_token@ field is required. @grant_type@ is
-- mandatory. The client id, client secret and redirect URI are optional.
instance FromForm ClientData where
  fromForm form = do
    let fromCode    = Left . ACode <$> parseUnique "code" form
        fromRefresh = parseMaybe @String "scope" form
                        *> (Right . RToken <$> parseUnique "refresh_token" form)
    authId      <- fromCode <|> fromRefresh
    grantTy     <- parseUnique "grant_type" form
    cid         <- parseMaybe "client_id" form
    csecret     <- parseMaybe "client_secret" form
    redirectUri <- parseMaybe "redirect_uri" form
    return ClientData
      { authID       = authId
      , grantType    = grantTy
      , clientID     = cid
      , clientSecret = csecret
      , redirect     = redirectUri
      }
|
|
|
|
-- | Orphan 'Error' instance for 'Text': string messages are packed verbatim.
-- NOTE(review): presumably needed because the @Alternative (Either e)@
-- instance from "Control.Monad.Trans.Error" requires @Error e@, and '<|>' is
-- used on @Either Text@ in the 'FromForm' instance of 'ClientData' — confirm.
instance Error Text where
  strMsg msg = pack msg
|
|
|
|
|
|
-- | Token endpoint: @POST /token@ with a URL-encoded form body. The handler
-- decodes the form as 'ClientData' and answers with a JSON token.
type Token
  =  "token"
  :> ReqBody '[FormUrlEncoded] Form --ClientData
  :> Post '[JSON] JWTWrapper
|
|
|
|
|
|
-- | Handler of the @/token@ endpoint.
--
-- Decodes the form body as 'ClientData' and authenticates the client when both
-- @client_id@ and @client_secret@ are supplied. It then does one of two things:
--
-- * exchanges an authorisation code for a token (@grant_type=authorization_code@), or
-- * renews a token from a refresh token (@grant_type=refresh_token@).
--
-- Every failure is reported as HTTP 500 with a short message.
tokenEndpoint :: forall user userData . UserData user userData => AuthServer user Token
tokenEndpoint = provideToken
  where
    provideToken :: Form -> AuthHandler user JWTWrapper
    provideToken clienty = do
      liftIO . putStrLn $ "Mock Server: received client data @ /token: " ++ show clienty
      let parsed = fromForm @ClientData clienty
      liftIO $ print parsed
      -- A malformed form used to hit an irrefutable 'Right' pattern and crash
      -- the handler; report the parse error instead.
      client <- either (\e -> throwError err500 { errBody = fromString (unpack e) }) return parsed
      checkClient client
      case authID client of
        Left (ACode authCode) -> do
          unless (grantType client == "authorization_code") . throwError $ err500 { errBody = "Invalid grant_type" }
          mUser <- asks (verify (pack authCode) (clientID client)) >>= liftIO -- TODO verify redirect url here
          (user, scopes) <- maybe (throwError err500 { errBody = "Invalid authorisation code" }) return mUser
          token <- asks (mkToken @user user scopes) >>= liftIO
          liftIO . putStrLn $ "token: " ++ show token
          return token
        Right (RToken jwtw) -> do
          unless (grantType client == "refresh_token") . throwError $ err500 { errBody = "Invalid grant_type" }
          liftIO $ putStrLn "... checking refresh token"
          mToken <- asks (renewToken @user jwtw) >>= liftIO
          liftIO $ putStrLn "woohoo"
          case mToken of
            Just token -> liftIO (putStrLn $ "refreshed token: " ++ show token) >> return token
            Nothing -> throwError $ err500 { errBody = "Invalid refresh token" }

    -- | Rejects the request unless the supplied credentials match a trusted
    -- client. Requests lacking @client_id@ or @client_secret@ are not checked.
    checkClient :: ClientData -> AuthHandler user ()
    checkClient client = case (clientID client, clientSecret client) of
      (Just cid, Just csecret) ->
        unless (Client (pack cid) (pack csecret) `elem` trustedClients) .
          throwError $ err500 { errBody = "Invalid client" }
      _ -> return ()
|
|
|
|
|
|
----------------------
|
|
---- Users Endpoint ----
|
|
----------------------
|
|
|
|
|
|
-- | Common path prefix of the user endpoints.
type Users = "users"

-- | A header that must be present and must parse.
type HeaderR = Header' [Strict, Required]
|
|
|
|
-- | @GET /users/me@: data of the user owning the bearer token in the
-- @Authorization@ header.
type Me userData
  =  Users
  :> "me"
  :> HeaderR "Authorization" Text
  :> Get '[JSON] (Maybe userData)
|
|
|
|
-- | @GET /users/query?id=…@: data of another user, requested with a bearer
-- token in the @Authorization@ header.
type UserList userData
  =  Users
  :> "query"
  :> HeaderR "Authorization" Text
  :> QParam "id" Text
  :> Get '[JSON] [userData]
|
|
|
|
|
|
-- | Validates the value of an @Authorization: Bearer …@ header.
--
-- The header value is processed in four steps:
--
-- 1. Strips the @Bearer @ prefix.
-- 2. Decodes the token with 'decodeToken'.
-- 3. Parses the JWE payload as 'JWT'.
-- 4. Looks the token's @jti@ up in the state's @activeTokens@.
--
-- Returns the associated user and scopes, or 'Nothing' when the token is not
-- active. A bad prefix, a decoding error, a non-JWE token or an unparseable
-- payload is reported as HTTP 500.
verifyToken :: forall user userData . UserData user userData => Text -> AuthHandler user (Maybe (user, [Scope user]))
verifyToken jwtw = do
  token <- case stripPrefix "Bearer " jwtw of
    Nothing  -> throwError $ err500 { errBody = "Invalid token format" }
    Just raw -> asks (decodeToken @user raw) >>= liftIO
  liftIO $ putStrLn "decoded token:" >> print token
  case token of
    Left e -> throwError $ err500 { errBody = fromString $ show e }
    Right (Jwe (_, body)) ->
      -- 'decode' instead of 'fromJust . decode': a malformed payload must not
      -- crash the handler.
      case decode @JWT (fromStrict body) of
        Nothing  -> throwError $ err500 { errBody = "Invalid token payload" }
        Just jwt -> do
          -- TODO check if token grants access, then read logged in user from cookie
          liftIO $ print jwt
          ask >>= liftIO . (atomically . readTVar >=> return . Map.lookup (jti jwt) . activeTokens)
    -- Only encrypted (JWE) tokens are issued; anything else used to fall
    -- through a non-exhaustive case.
    Right _ -> throwError $ err500 { errBody = "Unsupported token type" }
|
|
|
|
|
|
-- | Handler of @/users/me@. It returns the token owner's data, assembled by
-- combining 'userScope' over every scope granted to the token. An unknown
-- token is reported as HTTP 500.
userEndpoint :: forall user userData . UserData user userData => AuthServer user (Me userData)
userEndpoint = handleUserData
  where
    handleUserData :: Text -> AuthHandler user (Maybe userData)
    handleUserData jwtw =
      verifyToken @user @userData jwtw >>=
        maybe (throwError $ err500 { errBody = "Unknown token" })
              (\(owner, granted) -> return . Just . mconcat $ map (userScope @user @userData owner) granted)
|
|
|
|
|
|
-- | Handler of @/users/query@.
--
-- Looks up the user with the given id and returns their data. The data is
-- restricted to the scopes granted to the requesting token, not to the
-- queried user. An unknown token or an unknown user is reported as HTTP 500.
userListEndpoint :: forall user userData . UserData user userData => AuthServer user (UserList userData)
userListEndpoint = handleUserData
  where
    handleUserData :: Text -> Text -> AuthHandler user [userData]
    handleUserData jwtw userID = do
      -- Replaces a leftover gibberish debug line that also rang the terminal bell.
      liftIO . putStrLn $ "Mock Server: received user query @ /users/query for id " ++ show userID
      mAdmin <- verifyToken @user @userData jwtw
      scopes <- maybe (throwError $ err500 { errBody = "Unknown token" }) (return . snd) mAdmin
      -- TODO check if this user is allowed to query other users
      mUser <- liftIO $ lookupUser @user @userData userID Nothing
      case mUser of
        Just u -> return [mconcat $ map (userScope @user @userData u) scopes] -- TODO support queries that fit for multiple users
        Nothing -> throwError $ err500 { errBody = "This user does not exist" }
|
|
|
|
|
|
-------------------
|
|
---- Server Main ----
|
|
-------------------
|
|
|
|
-- | Complete API of the mock server: authorisation, code, token and user
-- endpoints. The @user@ parameter is not used by the API type itself.
type Routing user userData
  =    Auth
  :<|> AuthCode
  :<|> Token
  :<|> Me userData
  :<|> UserList userData
|
|
|
|
-- | Handlers for 'Routing', in the same order as the API alternatives.
routing :: forall user userData . UserData user userData => AuthServer user (Routing user userData)
routing =
       loginServer      @user @userData
  :<|> codeServer       @user @userData
  :<|> tokenEndpoint    @user @userData
  :<|> userEndpoint     @user @userData
  :<|> userListEndpoint @user @userData
|
|
|
|
|
|
|
|
-- insecureOAuthMock :: Application
|
|
-- insecureOAuthMock = authAPI `serve` exampleAuthServer
|
|
|
|
-- | WAI application serving the whole mock API. Every handler shares the
-- given state through 'toHandler'.
insecureOAuthMock' :: forall user userData . UserData user userData => AuthState user -> Application
insecureOAuthMock' s = serve authAPI server
  where
    authAPI = Proxy @(Routing user userData)
    server  = hoistServer authAPI (toHandler @user @userData s) (routing @user @userData)
|
|
|
|
-- authenticate :: [User] -> BasicAuthCheck User
|
|
-- authenticate users = BasicAuthCheck $ \authData -> do
|
|
-- let
|
|
-- (uEmail, uPass) = (,) <$> (decodeUtf8 . basicAuthUsername) <*> (decodeUtf8 . basicAuthPassword) $ authData
|
|
-- case (find (\u -> email u == uEmail) users) of
|
|
-- Nothing -> return NoSuchUser
|
|
-- Just u -> return $ if uPass == password u then Authorized u else BadPassword
|
|
|
|
-- frontend :: BasicAuthData -> ClientM (Map.Map Text Text)
|
|
-- frontend ba = client authAPI ba "[ID]" "42" "code" ""
|
|
|
|
-- | Creates a fresh state with 'mkState' and serves the mock API with Warp's
-- 'run' on the given port.
runMockServer :: forall user userData . UserData user userData => Int -> IO ()
runMockServer port =
  mkState @user @userData >>= run port . insecureOAuthMock' @user @userData
|
|
|
|
-- runMockServer' :: Int -> IO ()
|
|
-- runMockServer' port = do
|
|
-- mgr <- newManager defaultManagerSettings
|
|
-- state <- mkState
|
|
-- bracket (forkIO . run port $ insecureOAuthMock' testUsers state) killThread $ \_ ->
|
|
-- runClientM (frontend $ BasicAuthData "foo@bar.com" "0000") (mkClientEnv mgr (BaseUrl Http "localhost" port ""))
|
|
-- >>= print
|
|
|
|
-- | Creates the initial server state in a new 'TVar'. The state holds:
--
-- * a freshly generated RSA key pair (key id @Oauth2MockKey@, key use 'Enc');
-- * no active authorisation codes;
-- * no active tokens.
--
-- NOTE(review): the 256 passed to 'generateRsaKeyPair' is believed to be the
-- key size in bytes (i.e. 2048 bits) — confirm against jose-jwt.
mkState :: forall user userData . UserData user userData => IO (AuthState user)
mkState = do
  (publicKey, privateKey) <- generateRsaKeyPair 256 (KeyId "Oauth2MockKey") Enc Nothing
  newTVarIO State{..}
  where
    activeCodes  = Map.empty
    activeTokens = Map.empty
|
|
|