oauth2-mock-server/src/Server.hs
2024-01-16 21:33:19 +01:00

378 lines
14 KiB
Haskell

{-# LANGUAGE DataKinds, TypeOperators, OverloadedStrings, ScopedTypeVariables, TypeApplications, RecordWildCards, AllowAmbiguousTypes #-}
module Server
{-( insecureOAuthMock'
, runMockServer
-- , runMockServer'
)-} where
import Control.Applicative ((<|>))
import Control.Concurrent
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TVar (newTVarIO, readTVar, modifyTVar, modifyTVar')
import Control.Exception (bracket)
import Control.Monad (unless, (>=>))
import Control.Monad.IO.Class
import Control.Monad.Trans.Reader
import Data.Aeson
import Data.ByteString (ByteString (..), fromStrict, toStrict)
import qualified Data.ByteString.Char8 as BS
import Data.List (find, elemIndex)
import qualified Data.Map.Strict as Map
import Data.Maybe (fromMaybe, fromJust, isJust, isNothing)
import Data.String (IsString (..))
import Data.Text hiding (elem, find, head, length, map, null, splitAt, tail, words)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Text.Encoding.Base64
import Data.Time.Clock (NominalDiffTime (..), nominalDay, UTCTime(..), getCurrentTime, addUTCTime)
import Data.UUID.V4
import GHC.Read (readPrec, lexP)
import Jose.Jwa
import Jose.Jwe
import Jose.Jwk (generateRsaKeyPair, generateSymmetricKey, KeyUse(Enc), KeyId)
import Jose.Jwt hiding (decode, encode)
import Network.HTTP.Client (newManager, defaultManagerSettings)
import Network.Wai.Handler.Warp
import Servant
import Servant.API
import Servant.Client
import Text.ParserCombinators.ReadPrec (look, pfail)
import Text.Read (readMaybe)
import qualified Text.Read.Lex as Lex
import Web.Internal.FormUrlEncoded (FromForm(..), parseUnique, parseMaybe)

import AuthCode
import LoginForm
import User
-- | A registered OAuth2 client: its client id and shared secret.
data AuthClient = Client
  { ident  :: Text -- ^ client id, compared with the id sent in requests
  , secret :: Text -- ^ client secret, compared with @client_secret@
  }
  deriving (Show, Eq)

-- | The clients this mock server accepts.
--
-- TODO move to db
trustedClients :: [AuthClient]
trustedClients =
  [ Client { ident = "42", secret = "shhh" }
  ]
-- | The values an authorisation request may carry in @response_type@.
data ResponseType
  = Code    -- ^ authorisation code grant
  | Token   -- ^ implicit grant via access token
  | IDToken -- ^ implicit grant via access token & ID token
  deriving (Eq, Show)

-- | Accepts exactly one of the wire names @code@, @token@ or @id_token@
-- as the whole input (surrounding whitespace is skipped by the lexer).
-- Anything else fails to parse.
instance Read ResponseType where
  readPrec = do
      lexeme <- lexP
      name <- case lexeme of
        Lex.Ident s -> return s
        _           -> pfail
      -- nothing may follow the single identifier
      Lex.EOF <- lexP
      maybe pfail return (lookup name wireNames)
    where
      wireNames :: [(String, ResponseType)]
      wireNames =
        [ ("code", Code)
        , ("token", Token)
        , ("id_token", IDToken)
        ]
------------------------------
---- Authorisation endpoint ----
------------------------------
-- Value types of the query parameters and headers used by the
-- authorisation and code endpoints.
type QScope    = String -- space-separated list of scopes
type QClient   = String
type QResType  = String
type QRedirect = Text
type QState    = Text
type QAuth     = Text

-- | A query parameter that must be present and must parse.
type QParam = QueryParam' [Required, Strict]
-- Earlier variants of the authorisation route, kept for reference:
-- type Oauth2Params = QParam "scope" QScope
--   :> QParam "client_id" QClient
--   :> QParam "response_type" QResType
--   :> QParam "redirect_uri" QRedirect
--   :> QueryParam "state" QState
-- type ProtectedAuth user = BasicAuth "login" user :> "auth" :> Auth -- Prompts for username & password
-- type QuickAuth = "qauth" :> Auth -- Prompts for username only

-- | Authorisation endpoint.
--
-- Requires @scope@, @client_id@, @response_type@ and @redirect_uri@.
-- @state@ is optional. Responds with the HTML login page.
type Auth = "auth"
  :> QParam "scope" QScope
  :> QParam "client_id" QClient
  :> QParam "response_type" QResType
  :> QParam "redirect_uri" QRedirect
  :> QueryParam "state" QState
  :> Get '[HTML] Html -- login
-- | Code endpoint (presumably called from the login page).
--
-- Credentials come in @Authorization@. The authorisation request's
-- parameters are repeated as @OA2_*@ headers; @OA2_State@ is optional.
-- Returns the redirect URL carrying the new authorisation code.
type AuthCode = "code"
  :> HeaderR "Authorization" QAuth
  :> HeaderR "OA2_Scope" QScope
  :> HeaderR "OA2_Client_ID" QClient
  :> HeaderR "OA2_Redirect_URI" QRedirect
  :> Header "OA2_State" QState
  :> Get '[JSON] Text -- returns auth code
-- | Handler monad of the mock server: Servant's 'Handler' with read access
-- to the shared 'AuthState'.
type AuthHandler user = ReaderT (AuthState user) Handler

-- | Implementation of the API @a@ in the 'AuthHandler' monad.
type AuthServer user a = ServerT a (AuthHandler user)

-- | Runs an 'AuthHandler' computation against the given state.
-- This is the transformation passed to 'hoistServer'.
toHandler :: forall user userData a . UserData user userData => AuthState user -> AuthHandler user a -> Handler a
toHandler = flip runReaderT
-- | Authorisation endpoint handler.
--
-- Only the authorisation code grant (@response_type=code@) is supported.
-- Any other, or unparseable, response type is rejected with an error
-- response. On success the login page is served with the request
-- parameters as @OA2_*@ entries. These are presumably sent back as headers
-- to the code endpoint.
loginServer :: forall user userData . UserData user userData => AuthServer user Auth
loginServer = handleAuth
  where
    handleAuth :: QScope
               -> QClient
               -> QResType
               -> QRedirect
               -> Maybe QState
               -> AuthHandler user Html
    handleAuth scopes client responseType url mState = do
      -- 'readMaybe' rather than 'read': a malformed response_type must give
      -- an error response, not an exception escaping from pure code.
      unless (readMaybe @ResponseType responseType == Just Code) $
        throwError err500 { errBody = "Unsupported response type" }
      let headers = Map.fromList @Text @Text
            [ ("OA2_Scope", pack scopes)
            , ("OA2_Client_ID", pack client)
            , ("OA2_Redirect_URI", url)
            ]
          -- the state parameter is optional and only forwarded when present
          headers' = maybe headers (\s -> Map.insert "OA2_State" s headers) mState
      return $ loginPage headers'
-- | Credential endpoint for the login form.
--
-- The handler:
--
-- * checks that the client is trusted;
-- * decodes @user:password@ from the base64 payload of the
--   @Authorization@ header;
-- * looks the user up and, on success, generates an authorisation code.
--
-- The response body is the redirect URI with @code@ (and @state@, when
-- given) appended to its query string.
codeServer :: forall user userData . UserData user userData => AuthServer user AuthCode
codeServer = handleCreds
  where
    handleCreds :: QAuth
                -> QScope
                -> QClient
                -> QRedirect
                -> Maybe QState
                -> AuthHandler user Text
    handleCreds creds scopes client url mState = do
      -- TODO fetch trusted clients from db | TODO also check if the redirect url really belongs to the client
      unless (pack client `elem` map ident trustedClients) $
        throwError err404 { errBody = "Not a trusted client." }
      let scopes' = map (readScope @user @userData) (words scopes)
      (userName, password) <-
        maybe (throwError err500 { errBody = "Malformed credentials." }) return $
          splitCredentials (decodeBase64Lenient creds)
      liftIO $ print userName
      mUser <- liftIO $ lookupUser @user @userData userName password
      u <- maybe (throwError err500 { errBody = "Unknown user."}) return mUser
      mAuthCode <- asks (genUnencryptedCode (AuthRequest client 600 u scopes') (unpack url)) >>= liftIO
      liftIO $ print mAuthCode
      liftIO . putStrLn $ "user: " ++ show u ++ " | scopes: " ++ show (map (showScope @user @userData) scopes')
      redirectTo $ addParams url mAuthCode mState

    -- The URL is returned as the response body, not issued as an HTTP
    -- redirect (see the commented-out 303 alternative).
    redirectTo :: Maybe Text -> AuthHandler user Text
    redirectTo (Just url) = liftIO (print url) >> return url --throwError err303 { errHeaders = [("Location", url)]}
    redirectTo Nothing = throwError err500 { errBody = "Could not generate authorisation code."}

    -- Split "user:password" at the FIRST colon (RFC 7617: the password may
    -- itself contain colons). Nothing when no colon is present.
    splitCredentials :: Text -> Maybe (Text, Text)
    splitCredentials decoded = case T.breakOn ":" decoded of
      (name, rest) | not (T.null rest) -> Just (name, T.drop 1 rest)
      _ -> Nothing

    -- Put @code@ first in the redirect URI's query string, then keep any
    -- query the URI already had (everything after the first '?'), then add
    -- @state@ if given.
    addParams :: Text -> Maybe Text -> Maybe Text -> Maybe Text
    addParams _ Nothing _ = Nothing
    addParams url (Just code) mState =
      let (base, query) = T.breakOn "?" url
          existing = T.drop 1 query -- drop the '?' itself
          rest = if T.null existing then "" else "&" <> existing
          rState = maybe "" (\s -> "&state=" <> percentEncode s) mState
      in Just $ base <> "?code=" <> code <> rest <> rState

    -- Percent-encode everything except RFC 3986 unreserved characters.
    -- Non-ASCII characters are encoded byte-wise from their UTF-8 form.
    percentEncode :: Text -> Text
    percentEncode = T.concatMap escapeChar
      where
        escapeChar c
          | isUnreserved c = T.singleton c
          | otherwise = T.concat [ percentByte b | b <- BS.unpack (encodeUtf8 (T.singleton c)) ]
        isUnreserved c =
          ('A' <= c && c <= 'Z') || ('a' <= c && c <= 'z') || ('0' <= c && c <= '9')
            || c `elem` ("-._~" :: String)
        percentByte b =
          let n = fromEnum b
          in T.pack ['%', hexDigit (n `div` 16), hexDigit (n `mod` 16)]
        hexDigit i = "0123456789ABCDEF" !! i
----------------------
---- Token Endpoint ----
----------------------
-- | Form body of a token request. Only the authorisation code flow is
-- modelled.
--
-- TODO support other flows
data ClientData = ClientData
  { authCode     :: String       -- ^ the @code@ issued by the code endpoint
  , clientID     :: Maybe String -- ^ optional @client_id@
  , clientSecret :: Maybe String -- ^ optional @client_secret@
  , redirect     :: Maybe String -- ^ optional @redirect_uri@ (not yet verified)
  }
  deriving Show
-- | Marker for the only supported @grant_type@, @authorization_code@.
data AuthFlow = AuthFlow

instance FromHttpApiData AuthFlow where
  parseQueryParam "authorization_code" = Right AuthFlow
  -- Name the offending value, so the 400 response explains what is wrong.
  parseQueryParam other = Left $ "Unsupported grant_type: " <> other

-- | Parses a token request form.
--
-- @grant_type@ must be @authorization_code@; its value is otherwise
-- discarded. @code@ is required. The remaining fields are optional.
instance FromForm ClientData where
  fromForm f = ClientData
    <$> (parseUnique @AuthFlow "grant_type" f *> parseUnique "code" f)
    <*> parseMaybe "client_id" f
    <*> parseMaybe "client_secret" f
    <*> parseMaybe "redirect_uri" f
-- | Token endpoint response: the serialised access token and its lifetime.
data JWTWrapper = JWTW
  { token     :: String          -- ^ serialised (encrypted) access token
  , expiresIn :: NominalDiffTime -- ^ lifetime of the token
  }
  deriving (Show)

-- | Serialised as an OAuth2 token response with @access_token@,
-- @token_type@ (always @JWT@) and @expires_in@.
instance ToJSON JWTWrapper where
  toJSON (JWTW t e) = object
    [ "access_token" .= t
    , "token_type" .= ("JWT" :: Text)
    , "expires_in" .= e
    ]

-- | Inverse of the 'ToJSON' instance; @token_type@ is ignored. A non-object
-- value is a parse failure rather than a pattern-match exception.
instance FromJSON JWTWrapper where
  parseJSON = withObject "JWTWrapper" $ \o -> JWTW
    <$> o .: "access_token"
    <*> o .: "expires_in"

-- | Reads a JSON-encoded token response from a header value.
instance FromHttpApiData JWTWrapper where
  -- Satisfies the class's MINIMAL pragma (parseUrlPiece | parseQueryParam).
  -- Otherwise the two defaults refer to each other and loop.
  parseUrlPiece = parseHeader . encodeUtf8
  parseHeader bs = case decode (fromStrict bs) of
    Just x  -> Right x
    Nothing -> Left "Invalid JWT wrapper"
-- | Token endpoint: exchanges an authorisation code, sent as a
-- form-encoded body, for an access token.
type Token = "token"
  :> ReqBody '[FormUrlEncoded] ClientData
  :> Post '[JSON] JWTWrapper
-- | Token endpoint handler.
--
-- If both @client_id@ and @client_secret@ are supplied, they must match a
-- trusted client. If either is missing, the client check is skipped.
-- NOTE(review): this lets a caller bypass client authentication by
-- omitting the secret; confirm that is intended for this mock.
--
-- The authorisation code is then checked with 'verify' and exchanged for a
-- freshly minted access token (see 'mkToken').
tokenEndpoint :: forall user userData . UserData user userData => AuthServer user Token
tokenEndpoint = provideToken
  where
    provideToken :: ClientData -> AuthHandler user JWTWrapper
    provideToken client = do
      unless (clientAuthenticates client) $
        throwError err500 { errBody = "Invalid client" }
      mUser <- asks (verify (pack $ authCode client) (clientID client)) >>= liftIO -- TODO verify redirect url here
      (user, scopes) <- maybe (throwError err500 { errBody = "Invalid authorisation code" }) return mUser
      accessToken <- asks (mkToken @user @userData user scopes) >>= liftIO
      liftIO . putStrLn $ "token: " ++ show accessToken
      return accessToken

    -- Credentials are only checked when both halves are present.
    clientAuthenticates :: ClientData -> Bool
    clientAuthenticates cd = case (clientID cd, clientSecret cd) of
      (Just cid, Just sec) -> Client (pack cid) (pack sec) `elem` trustedClients
      _                    -> True
-- | Mints an access token for user @u@ with the given scopes.
--
-- Steps:
--
-- * Build a 'JWT' from the fixed name @\"Oauth2MockServer\"@, an expiry
--   one hour from now, and a fresh UUID (later read back as its @jti@).
-- * Encrypt it with the server's public key (RSA-OAEP-256 / A128GCM).
-- * Record the UUID in 'activeTokens', mapped to the user and scopes.
--
-- Encryption failures are raised with 'error'.
mkToken :: forall user userData . UserData user userData
        => user -> [Scope user] -> AuthState user -> IO JWTWrapper
mkToken u scopes state = do
  pubKey <- atomically $ publicKey <$> readTVar state
  now <- getCurrentTime
  uuid <- nextRandom
  let
    lifetime = nominalDay / 24 -- TODO make configurable
    jwt = JWT "Oauth2MockServer" (lifetime `addUTCTime` now) uuid
  encoded <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode jwt)
  case encoded of
    Right (Jwt serialised) -> do
      -- Strict modify, so update thunks do not pile up inside the TVar.
      atomically . modifyTVar' state $ \s -> s { activeTokens = Map.insert uuid (u, scopes) (activeTokens s) }
      return $ JWTW (BS.unpack serialised) lifetime
    Left e -> error $ "mkToken: could not encrypt access token: " ++ show e
----------------------
---- Users Endpoint ----
----------------------
-- | Common path prefix of the user endpoints.
type Users = "users"

-- | A header that must be present and must parse.
type HeaderR = Header' [Strict, Required]

-- | Data of the user owning the bearer token sent in @Authorization@.
type Me userData = Users
  :> "me"
  :> HeaderR "Authorization" Text
  :> Get '[JSON] (Maybe userData)

-- | User listing (its handler is not implemented yet).
type UserList userData = Users
  :> "query"
  :> HeaderR "Authorization" Text
  :> Get '[JSON] [userData] -- TODO support query params
-- | @/users/me@ handler.
--
-- Steps:
--
-- * Expect @Authorization: Bearer <token>@.
-- * Decrypt the token with the server's private key and read its @jti@.
-- * Return the data of the user the token was issued to: the
--   'userScope' results for each granted scope, combined with 'mconcat'.
--
-- Malformed, non-JWE, undecodable or unknown tokens get an error response
-- instead of crashing the handler.
--
-- NOTE(review): the token's expiry is not checked here.
userEndpoint :: forall user userData . UserData user userData => AuthServer user (Me userData)
userEndpoint = handleUserData
  where
    handleUserData :: Text -> AuthHandler user (Maybe userData)
    handleUserData jwtw = do
      rawToken <- maybe (throwError err500 { errBody = "Invalid token format"}) return $
        T.stripPrefix "Bearer " jwtw
      decoded <- asks (decodeToken @user @userData rawToken) >>= liftIO
      liftIO $ putStrLn "decoded token:" >> print decoded
      body <- case decoded of
        Left e -> throwError err500 { errBody = fromString $ show e }
        Right (Jwe (_, payload)) -> return payload
        -- A JWS or unsecured token cannot be one this server issued.
        Right _ -> throwError err500 { errBody = "Invalid token" }
      jwt <- maybe (throwError err500 { errBody = "Invalid token payload" }) return $
        decode @JWT (fromStrict body)
      -- TODO check if token grants access, then read logged in user from cookie
      liftIO $ print jwt
      tokens <- ask >>= liftIO . atomically . fmap activeTokens . readTVar
      case Map.lookup (jti jwt) tokens of
        Just (u, scopes) -> return . Just . mconcat $ map (userScope @user @userData u) scopes
        Nothing -> throwError err500 { errBody = "Unknown token" }
-- | Decrypts a serialised token with the server's private key.
decodeToken :: forall user userData . UserData user userData => Text -> AuthState user -> IO (Either JwtError JwtContent)
decodeToken serialised state = do
  key <- atomically (privateKey <$> readTVar state)
  jwkDecode key (encodeUtf8 serialised)
-- | @/users/query@ handler. Not implemented yet: it always answers
-- 501 Not Implemented with an explanatory body, whatever the header.
userListEndpoint :: forall user userData . UserData user userData => AuthServer user (UserList userData)
userListEndpoint = handleUserData
  where
    handleUserData :: Text -> AuthHandler user [userData]
    handleUserData _ = throwError err501 { errBody = "User queries are not implemented yet." }
-------------------
---- Server Main ----
-------------------
-- | The complete API of the mock server.
type Routing user userData =
       Auth
  :<|> AuthCode
  :<|> Token
  :<|> Me userData
  :<|> UserList userData
  -- :<|> "qauth" :> Get '[HTML] Html
-- | Handlers for 'Routing', listed in the same order as its routes.
routing :: forall user userData . UserData user userData => AuthServer user (Routing user userData)
routing =
       loginServer @user @userData
  :<|> codeServer @user @userData
  :<|> tokenEndpoint @user @userData
  :<|> userEndpoint @user @userData
  :<|> userListEndpoint @user @userData
  -- :<|> return (loginPage "/foobar")
-- insecureOAuthMock :: Application
-- insecureOAuthMock = authAPI `serve` exampleAuthServer

-- | WAI application serving the mock OAuth2 API, backed by the given state.
insecureOAuthMock' :: forall user userData . UserData user userData => AuthState user -> Application
insecureOAuthMock' authState =
  serve authAPI (hoistServer authAPI (toHandler @user @userData authState) (routing @user @userData))
  where
    authAPI = Proxy @(Routing user userData)
-- authenticate :: [User] -> BasicAuthCheck User
-- authenticate users = BasicAuthCheck $ \authData -> do
-- let
-- (uEmail, uPass) = (,) <$> (decodeUtf8 . basicAuthUsername) <*> (decodeUtf8 . basicAuthPassword) $ authData
-- case (find (\u -> email u == uEmail) users) of
-- Nothing -> return NoSuchUser
-- Just u -> return $ if uPass == password u then Authorized u else BadPassword
-- frontend :: BasicAuthData -> ClientM (Map.Map Text Text)
-- frontend ba = client authAPI ba "[ID]" "42" "code" ""
-- | Runs the mock OAuth2 server on the given port, using a freshly created
-- state (new key pair, no active codes or tokens).
runMockServer :: forall user userData . UserData user userData => Int -> IO ()
runMockServer port =
  mkState @user @userData >>= run port . insecureOAuthMock' @user @userData
-- runMockServer' :: Int -> IO ()
-- runMockServer' port = do
-- mgr <- newManager defaultManagerSettings
-- state <- mkState
-- bracket (forkIO . run port $ insecureOAuthMock' testUsers state) killThread $ \_ ->
-- runClientM (frontend $ BasicAuthData "foo@bar.com" "0000") (mkClientEnv mgr (BaseUrl Http "localhost" port ""))
-- >>= print
-- | Fresh server state.
--
-- Holds a newly generated RSA key pair and empty tables of active codes
-- and tokens. The public key encrypts issued tokens; the private key
-- decrypts them.
mkState :: forall user userData . UserData user userData => IO (AuthState user)
mkState = do
  (pub, priv) <- generateRsaKeyPair 256 (KeyId "Oauth2MockKey") Enc Nothing
  newTVarIO State
    { publicKey    = pub
    , privateKey   = priv
    , activeCodes  = Map.empty
    , activeTokens = Map.empty
    }