378 lines
14 KiB
Haskell
378 lines
14 KiB
Haskell
{-# LANGUAGE DataKinds, TypeOperators, OverloadedStrings, ScopedTypeVariables, TypeApplications, RecordWildCards, AllowAmbiguousTypes #-}
|
|
|
|
module Server
|
|
{-( insecureOAuthMock'
|
|
, runMockServer
|
|
-- , runMockServer'
|
|
)-} where
|
|
|
|
import Control.Applicative ((<|>))
import Control.Concurrent
import Control.Exception (bracket)
import Control.Monad (unless, (>=>))
import Control.Monad.IO.Class
import Data.List (find, elemIndex)
import Data.Maybe (fromMaybe, fromJust, isJust, isNothing)
import Data.String (IsString (..))
import GHC.Read (readPrec, lexP)
import Text.ParserCombinators.ReadPrec (look, pfail)
import Text.Read (readMaybe)
import qualified Text.Read.Lex as Lex

import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TVar (newTVarIO, readTVar, readTVarIO, modifyTVar, modifyTVar')
import Control.Monad.Trans.Reader
import Data.Aeson
import Data.ByteString (ByteString (..), fromStrict, toStrict)
import qualified Data.ByteString.Char8 as BS
import qualified Data.Map.Strict as Map
import Data.Text hiding (elem, find, head, length, map, null, splitAt, tail, words)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Text.Encoding.Base64
import Data.Time.Clock (NominalDiffTime (..), nominalDay, UTCTime(..), getCurrentTime, addUTCTime)
import Data.UUID.V4
import Jose.Jwa
import Jose.Jwe
import Jose.Jwk (generateRsaKeyPair, generateSymmetricKey, KeyUse(Enc), KeyId)
import Jose.Jwt hiding (decode, encode)
import Network.HTTP.Client (newManager, defaultManagerSettings)
import Network.Wai.Handler.Warp
import Servant
import Servant.API
import Servant.Client
import Web.Internal.FormUrlEncoded (FromForm(..), parseUnique, parseMaybe)

import AuthCode
import LoginForm
import User
|
|
|
|
|
|
|
|
-- | An OAuth client known to the mock server.
data AuthClient = Client
  { ident  :: Text -- ^ client id, compared against the @client_id@ parameter
  , secret :: Text -- ^ client secret, compared against @client_secret@
  } deriving (Eq, Show)
|
|
|
|
-- | Clients that may obtain authorisation codes and tokens.
-- Hard-coded for the mock. TODO move to db
trustedClients :: [AuthClient]
trustedClients =
  [ Client { ident = "42", secret = "shhh" }
  ]
|
|
|
|
-- | Values of the OAuth @response_type@ parameter the server can parse.
-- The authorisation endpoint currently accepts only 'Code'.
data ResponseType
  = Code    -- ^ authorisation code grant
  | Token   -- ^ implicit grant via access token
  | IDToken -- ^ implicit grant via access token & ID token
  deriving (Show, Eq)
|
|
-- | Parses the wire names @code@, @token@ and @id_token@. The input must
-- consist of exactly one identifier token; anything else fails to parse.
instance Read ResponseType where
  readPrec = do
    Lex.Ident word <- lexP
    Lex.EOF <- lexP
    let wireNames =
          [ ("code", Code)
          , ("token", Token)
          , ("id_token", IDToken)
          ]
    maybe pfail pure (lookup word wireNames)
|
|
|
|
------------------------------
|
|
---- Authorisation endpoint ----
|
|
------------------------------
|
|
|
|
-- Parameter types of the authorisation and credential endpoints. Scope,
-- client id and response type stay 'String' because the handlers consume
-- them as strings (e.g. 'words', 'pack'); the remaining ones are 'Text'.
type QScope    = String
type QClient   = String
type QResType  = String
type QRedirect = Text
type QState    = Text
type QAuth     = Text

-- | A servant query parameter marked 'Required' and 'Strict'.
type QParam = QueryParam' [Required, Strict]
|
|
|
|
-- type Oauth2Params = QParam "scope" QScope
|
|
-- :> QParam "client_id" QClient
|
|
-- :> QParam "response_type" QResType
|
|
-- :> QParam "redirect_uri" QRedirect
|
|
-- :> QueryParam "state" QState
|
|
|
|
-- type ProtectedAuth user = BasicAuth "login" user :> "auth" :> Auth -- Prompts for username & password
|
|
-- type QuickAuth = "qauth" :> Auth -- Prompts for username only
|
|
-- | Authorisation endpoint, @GET \/auth@.
-- Query parameters @scope@, @client_id@, @response_type@ and @redirect_uri@
-- are required; @state@ is optional. Responds with the HTML login page.
type Auth =
  "auth"
    :> QParam "scope" QScope
    :> QParam "client_id" QClient
    :> QParam "response_type" QResType
    :> QParam "redirect_uri" QRedirect
    :> QueryParam "state" QState
    :> Get '[HTML] Html
|
|
|
|
-- | Credential endpoint, @GET \/code@.
-- Expects the user's credentials in @Authorization@ and the parameters of the
-- original authorisation request in the @OA2_*@ headers (the same keys the
-- login page is rendered with). Responds, as JSON text, with the client's
-- redirect URI extended by the issued authorisation code.
type AuthCode =
  "code"
    :> HeaderR "Authorization" QAuth
    :> HeaderR "OA2_Scope" QScope
    :> HeaderR "OA2_Client_ID" QClient
    :> HeaderR "OA2_Redirect_URI" QRedirect
    :> Header "OA2_State" QState
    :> Get '[JSON] Text
|
|
|
|
|
|
-- | Monad of every endpoint handler: servant's 'Handler' with read access to
-- the shared 'AuthState'.
type AuthHandler user = ReaderT (AuthState user) Handler

-- | A servant server for the API @a@ whose handlers run in 'AuthHandler'.
type AuthServer user a = ServerT a (AuthHandler user)
|
|
|
|
-- | Run an 'AuthHandler' against a fixed state, yielding a plain servant
-- 'Handler' (used with 'hoistServer').
toHandler :: forall user userData a . UserData user userData => AuthState user -> AuthHandler user a -> Handler a
toHandler = flip runReaderT
|
|
|
|
-- | Handler for 'Auth': validates @response_type@ and renders the login page.
--
-- The request parameters are handed to 'loginPage' as the @OA2_*@ values the
-- credential endpoint ('AuthCode') expects; @OA2_State@ is only included when
-- the client sent a @state@.
--
-- Only the authorisation-code flow is supported. Any other response type,
-- including one that does not parse at all, is answered with a 500
-- "Unsupported response type".
loginServer :: forall user userData . UserData user userData => AuthServer user Auth
loginServer = handleAuth
  where
    handleAuth
      :: QScope
      -> QClient
      -> QResType
      -> QRedirect
      -> Maybe QState
      -> AuthHandler user Html
    handleAuth scopes client responseType url mState = do
      let baseHeaders = Map.fromList @Text @Text
            [ ("OA2_Scope", pack scopes)
            , ("OA2_Client_ID", pack client)
            , ("OA2_Redirect_URI", url)
            ]
          headers = maybe baseHeaders (\s -> Map.insert "OA2_State" s baseHeaders) mState
      -- 'readMaybe' rather than the partial 'read': an unknown value must
      -- produce the error response, not an uncaught "no parse" exception.
      unless (readMaybe @ResponseType responseType == Just Code) $
        throwError err500 { errBody = "Unsupported response type" }
      return $ loginPage headers
|
|
|
|
-- | Handler for 'AuthCode', the credential endpoint behind the login form.
--
-- Steps:
--
-- * Reject clients that are not in 'trustedClients' (404).
-- * Authenticate the user from the @Authorization@ value: base64 of
--   @name:password@, optionally prefixed with @Basic @.
-- * Issue an authorisation code via 'genUnencryptedCode'.
-- * Return the client's redirect URI with @code@ (and @state@, when present)
--   added to its query string.
--
-- Malformed credentials, an unknown user, or a failure to generate a code are
-- answered with 500.
codeServer :: forall user userData . UserData user userData => AuthServer user AuthCode
codeServer = handleCreds
  where
    handleCreds
      :: QAuth
      -> QScope
      -> QClient
      -> QRedirect
      -> Maybe QState
      -> AuthHandler user Text
    handleCreds creds scopes client url mState = do
      -- TODO fetch trusted clients from db
      -- TODO also check if the redirect url really belongs to the client
      unless (pack client `elem` map ident trustedClients) $
        throwError err404 { errBody = "Not a trusted client." }
      (userName, password) <-
        maybe (throwError err500 { errBody = "Malformed credentials." }) return
          (splitCredentials creds)
      liftIO $ print userName
      mUser <- liftIO $ lookupUser @user @userData userName password
      u <- maybe (throwError err500 { errBody = "Unknown user." }) return mUser
      let scopes' = map (readScope @user @userData) $ words scopes
      mAuthCode <- asks (genUnencryptedCode (AuthRequest client 600 u scopes') (unpack url)) >>= liftIO
      liftIO $ print mAuthCode
      liftIO . putStrLn $ "user: " ++ show u ++ " | scopes: " ++ show (map (showScope @user @userData) scopes')
      respondWith $ addParams url mAuthCode mState

    -- Decode the credentials and split them at the FIRST ':' (as in HTTP
    -- Basic auth), so a password may itself contain colons. 'Nothing' when
    -- the decoded value contains no ':' at all.
    splitCredentials :: Text -> Maybe (Text, Text)
    splitCredentials raw =
      let encoded = fromMaybe raw (T.stripPrefix "Basic " raw)
          (name, rest) = T.breakOn ":" (decodeBase64Lenient encoded)
      in if T.null rest then Nothing else Just (name, T.drop 1 rest)

    -- Hand the finished redirect URI back to the caller, currently as the
    -- response body rather than as a real 303 redirect.
    respondWith :: Maybe Text -> AuthHandler user Text
    respondWith (Just target) =
      liftIO (print target) >> return target
      -- TODO: throwError err303 { errHeaders = [("Location", encodeUtf8 target)] }
    respondWith Nothing =
      throwError err500 { errBody = "Could not generate authorisation code." }

    -- Append @code@ and, if given, @state@ to the redirect URI's query
    -- string, keeping every query parameter the URI already carries
    -- (everything after its first '?'). 'Nothing' if no code was generated.
    addParams :: Text -> Maybe Text -> Maybe Text -> Maybe Text
    addParams _ Nothing _ = Nothing
    addParams url (Just code) mState =
      let (base, afterMark) = T.breakOn "?" url
          existing = T.drop 1 afterMark
          existing' = if T.null existing then "" else "&" <> existing
          stateParam = maybe "" (("&state=" <>) . percentEncode) mState
      in Just $ base <> "?code=" <> percentEncode code <> existing' <> stateParam

    -- RFC 3986 percent-encoding: unreserved characters are kept, everything
    -- else becomes %XX escapes of its UTF-8 bytes. This keeps characters such
    -- as '+', '&' or '%' (common in base64 state values) intact through the
    -- client's query-string decoding.
    percentEncode :: Text -> Text
    percentEncode = T.concatMap escape
      where
        escape :: Char -> Text
        escape c
          | isUnreserved c = T.singleton c
          | otherwise = T.concat [hexByte b | b <- BS.unpack (encodeUtf8 (T.singleton c))]

        isUnreserved :: Char -> Bool
        isUnreserved c =
          ('A' <= c && c <= 'Z') || ('a' <= c && c <= 'z') || ('0' <= c && c <= '9')
            || c `elem` ("-._~" :: String)

        -- BS is the Char8 interface, so each Char is one byte (0..255).
        hexByte :: Char -> Text
        hexByte b =
          let (hi, lo) = fromEnum b `divMod` 16
          in T.pack ['%', hexDigits !! hi, hexDigits !! lo]

        hexDigits :: String
        hexDigits = "0123456789ABCDEF"
|
|
|
|
|
|
|
|
----------------------
|
|
---- Token Endpoint ----
|
|
----------------------
|
|
|
|
|
|
-- | Form body of a token request. Its 'FromForm' instance requires
-- @grant_type=authorization_code@ and @code@; the other fields are optional.
data ClientData = ClientData -- TODO support other flows
  { authCode     :: String       -- ^ form field @code@
  , clientID     :: Maybe String -- ^ form field @client_id@
  , clientSecret :: Maybe String -- ^ form field @client_secret@
  , redirect     :: Maybe String -- ^ form field @redirect_uri@
  } deriving Show
|
|
|
|
-- | Witness for @grant_type=authorization_code@, the only grant type the
-- token endpoint accepts. Any other value is rejected, with the offending
-- value as the error text.
data AuthFlow = AuthFlow

instance FromHttpApiData AuthFlow where
  parseQueryParam grant
    | grant == "authorization_code" = Right AuthFlow
    | otherwise = Left grant
|
|
|
|
-- | Decode a token request form. @grant_type@ must be @authorization_code@ and
-- @code@ must be present exactly once. The client credentials and the
-- redirect URI are optional.
instance FromForm ClientData where
  fromForm form = do
    _ <- parseUnique @AuthFlow "grant_type" form
    code <- parseUnique "code" form
    ClientData code
      <$> parseMaybe "client_id" form
      <*> parseMaybe "client_secret" form
      <*> parseMaybe "redirect_uri" form
|
|
|
|
|
|
-- | Token endpoint response. Serialised as
-- @{"access_token": ..., "token_type": "JWT", "expires_in": ...}@.
data JWTWrapper = JWTW
  { token     :: String          -- ^ the encrypted token produced by 'mkToken'
  , expiresIn :: NominalDiffTime -- ^ lifetime of the token
  } deriving Show
|
|
|
|
-- | OAuth-style token response; @token_type@ is always @"JWT"@.
instance ToJSON JWTWrapper where
  toJSON JWTW{..} = object
    [ "access_token" .= token
    , "token_type"   .= ("JWT" :: Text)
    , "expires_in"   .= expiresIn
    ]
|
|
|
|
-- | Inverse of the 'ToJSON' instance: reads @access_token@ and @expires_in@;
-- @token_type@ is ignored. Input that is not a JSON object is reported as a
-- parse failure instead of crashing with a pattern-match error.
instance FromJSON JWTWrapper where
  parseJSON = withObject "JWTWrapper" $ \o ->
    JWTW
      <$> o .: "access_token"
      <*> o .: "expires_in"
|
|
|
|
-- | Decode a 'JWTWrapper' from its JSON form.
--
-- 'parseUrlPiece' is defined explicitly because it is part of the class's
-- minimal definition. Without it, 'parseUrlPiece' and 'parseQueryParam'
-- default to each other and loop forever. 'parseQueryParam' uses its default
-- ('parseUrlPiece').
instance FromHttpApiData JWTWrapper where
  parseUrlPiece = parseHeader . encodeUtf8
  parseHeader = maybe (Left "Invalid JWT wrapper") Right . decodeStrict
|
|
|
|
-- | Token endpoint, @POST \/token@, taking a form-urlencoded 'ClientData'
-- body and responding with a 'JWTWrapper' as JSON.
type Token =
  "token"
    :> ReqBody '[FormUrlEncoded] ClientData
    :> Post '[JSON] JWTWrapper
|
|
|
|
|
|
-- | Handler for 'Token': exchanges an authorisation code for an encrypted JWT.
--
-- If the form carries both @client_id@ and @client_secret@, they must match a
-- trusted client ("Invalid client", 500). When either one is missing, the
-- check is skipped.
-- NOTE(review): presumably deliberate for this insecure mock; confirm.
--
-- The code is then checked with 'verify' ("Invalid authorisation code", 500)
-- and a token is minted by 'mkToken'.
tokenEndpoint :: forall user userData . UserData user userData => AuthServer user Token
tokenEndpoint = provideToken
  where
    provideToken :: ClientData -> AuthHandler user JWTWrapper
    provideToken client = do
      case (clientID client, clientSecret client) of
        (Just cid, Just csecret)
          | not (Client (pack cid) (pack csecret) `elem` trustedClients) ->
              throwError err500 { errBody = "Invalid client" }
        _ -> return ()
      -- TODO verify redirect url here
      mUser <- asks (verify (pack $ authCode client) (clientID client)) >>= liftIO
      (user, scopes) <-
        maybe (throwError err500 { errBody = "Invalid authorisation code" }) return mUser
      jwtw <- asks (mkToken @user @userData user scopes) >>= liftIO
      liftIO . putStrLn $ "token: " ++ show jwtw
      return jwtw
|
|
|
|
|
|
-- | Mint an access token for user @u@ with the granted @scopes@.
--
-- Steps:
--
-- * Build the 'JWT' claims from the literal "Oauth2MockServer", an expiry
--   one hour from now, and a fresh random UUID.
-- * Encrypt them with the server's public key (RSA-OAEP-256, A128GCM).
-- * Record @(u, scopes)@ in 'activeTokens' under that UUID, so the user
--   endpoint can resolve the token later.
--
-- Throws an 'IOError' if encryption fails.
mkToken :: forall user userData . UserData user userData
        => user -> [Scope user] -> AuthState user -> IO JWTWrapper
mkToken u scopes state = do
  pubKey <- publicKey <$> readTVarIO state
  now <- getCurrentTime
  uuid <- nextRandom
  let lifetime = nominalDay / 24 -- one hour; TODO make configurable
      jwt = JWT "Oauth2MockServer" (lifetime `addUTCTime` now) uuid
  encoded <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode jwt)
  case encoded of
    Right (Jwt encodedJwt) -> do
      -- Force the new map (and, via modifyTVar', the record) so the TVar
      -- does not accumulate a chain of lazy inserts.
      let register s =
            let tokens' = Map.insert uuid (u, scopes) (activeTokens s)
            in tokens' `seq` s { activeTokens = tokens' }
      atomically $ modifyTVar' state register
      return $ JWTW (BS.unpack encodedJwt) lifetime
    Left e -> ioError . userError $ "mkToken: could not encrypt token: " ++ show e
|
|
|
|
|
|
----------------------
|
|
---- Users Endpoint ----
|
|
----------------------
|
|
|
|
|
|
-- | Common path prefix of the user-data endpoints.
type Users = "users"

-- | A request header servant treats as 'Required' and 'Strict'.
type HeaderR = Header' [Strict, Required]
|
|
|
|
-- | @GET \/users\/me@ with a @Bearer@ token in @Authorization@; responds with
-- the data of the user the token was issued to.
type Me userData =
  Users
    :> "me"
    :> HeaderR "Authorization" Text
    :> Get '[JSON] (Maybe userData)
|
|
|
|
-- | @GET \/users\/query@ with an @Authorization@ header; meant to return a
-- list of users, but its handler is not implemented yet.
type UserList userData =
  Users
    :> "query"
    :> HeaderR "Authorization" Text
    :> Get '[JSON] [userData] -- TODO support query params
|
|
|
|
|
|
-- | Handler for 'Me': resolves a @Bearer@ token to the user's data.
--
-- Steps:
--
-- * Decrypt the token with 'decodeToken' and decode its payload as 'JWT'.
-- * Look up the token id ('jti') in 'activeTokens'.
-- * Combine the user's data for each granted scope with 'mconcat'.
--
-- Every failure is answered with 500: missing @Bearer @ prefix, decryption
-- error, a non-JWE token, an undecodable payload, or an unknown token.
userEndpoint :: forall user userData . UserData user userData => AuthServer user (Me userData)
userEndpoint = handleUserData
  where
    handleUserData :: Text -> AuthHandler user (Maybe userData)
    handleUserData authHeader = do
      rawToken <-
        maybe (throwError err500 { errBody = "Invalid token format" }) return
          (stripPrefix "Bearer " authHeader)
      decoded <- asks (decodeToken @user @userData rawToken) >>= liftIO
      liftIO $ putStrLn "decoded token:" >> print decoded
      payload <- case decoded of
        Left e -> throwError err500 { errBody = fromString $ show e }
        Right (Jwe (_, body)) -> return body
        -- jwkDecode can also yield Jws / Unsecured content; reject instead of crashing
        Right _ -> throwError err500 { errBody = "Unexpected token type" }
      jwt <-
        maybe (throwError err500 { errBody = "Invalid token payload" }) return
          (decode @JWT $ fromStrict payload)
      -- TODO check if token grants access, then read logged in user from cookie
      liftIO $ print jwt
      state <- ask
      tokens <- liftIO $ activeTokens <$> readTVarIO state
      case Map.lookup (jti jwt) tokens of
        Just (u, scopes) -> return . Just . mconcat $ map (userScope @user @userData u) scopes
        Nothing -> throwError err500 { errBody = "Unknown token" }
|
|
|
|
|
|
-- | Decrypt and decode a token (as taken from a @Bearer@ header) with the
-- server's private key, read from the shared state.
decodeToken :: forall user userData . UserData user userData => Text -> AuthState user -> IO (Either JwtError JwtContent)
decodeToken token state =
  atomically (privateKey <$> readTVar state)
    >>= \prKey -> jwkDecode prKey (encodeUtf8 token)
|
|
|
|
-- | Handler for 'UserList'. Querying users is not implemented yet, so every
-- request is answered with 501 Not Implemented (rather than crashing).
userListEndpoint :: forall user userData . UserData user userData => AuthServer user (UserList userData)
userListEndpoint = handleUserData
  where
    handleUserData :: Text -> AuthHandler user [userData]
    handleUserData _authHeader =
      -- TODO support listing / querying users
      throwError err501 { errBody = "Listing users is not implemented." }
|
|
|
|
|
|
-------------------
|
|
---- Server Main ----
|
|
-------------------
|
|
|
|
-- | The complete API of the mock server. Only @userData@ is used on the right
-- hand side (by the user-data endpoints); @user@ does not occur there.
type Routing user userData =
       Auth
  :<|> AuthCode
  :<|> Token
  :<|> Me userData
  :<|> UserList userData
  -- :<|> "qauth" :> Get '[HTML] Html
|
|
|
|
-- | Handlers for 'Routing', in the same order as its endpoints.
routing :: forall user userData . UserData user userData => AuthServer user (Routing user userData)
routing =
       loginServer      @user @userData
  :<|> codeServer       @user @userData
  :<|> tokenEndpoint    @user @userData
  :<|> userEndpoint     @user @userData
  :<|> userListEndpoint @user @userData
  -- :<|> return (loginPage "/foobar")
|
|
|
|
|
|
|
|
-- insecureOAuthMock :: Application
|
|
-- insecureOAuthMock = authAPI `serve` exampleAuthServer
|
|
|
|
-- | WAI application serving 'Routing', with every handler reading the given
-- shared state. As the name says, this is a mock, not meant for production.
insecureOAuthMock' :: forall user userData . UserData user userData => AuthState user -> Application
insecureOAuthMock' s = serve api server
  where
    api = Proxy @(Routing user userData)
    server = hoistServer api (toHandler @user @userData s) (routing @user @userData)
|
|
|
|
-- authenticate :: [User] -> BasicAuthCheck User
|
|
-- authenticate users = BasicAuthCheck $ \authData -> do
|
|
-- let
|
|
-- (uEmail, uPass) = (,) <$> (decodeUtf8 . basicAuthUsername) <*> (decodeUtf8 . basicAuthPassword) $ authData
|
|
-- case (find (\u -> email u == uEmail) users) of
|
|
-- Nothing -> return NoSuchUser
|
|
-- Just u -> return $ if uPass == password u then Authorized u else BadPassword
|
|
|
|
-- frontend :: BasicAuthData -> ClientM (Map.Map Text Text)
|
|
-- frontend ba = client authAPI ba "[ID]" "42" "code" ""
|
|
|
|
-- | Create fresh server state (see 'mkState') and serve the mock with Warp on
-- the given port.
runMockServer :: forall user userData . UserData user userData => Int -> IO ()
runMockServer port =
  mkState @user @userData >>= run port . insecureOAuthMock' @user @userData
|
|
|
|
-- runMockServer' :: Int -> IO ()
|
|
-- runMockServer' port = do
|
|
-- mgr <- newManager defaultManagerSettings
|
|
-- state <- mkState
|
|
-- bracket (forkIO . run port $ insecureOAuthMock' testUsers state) killThread $ \_ ->
|
|
-- runClientM (frontend $ BasicAuthData "foo@bar.com" "0000") (mkClientEnv mgr (BaseUrl Http "localhost" port ""))
|
|
-- >>= print
|
|
|
|
-- | Fresh server state: a newly generated RSA key pair (key id
-- @Oauth2MockKey@, key use 'Enc') plus empty code and token tables.
-- NOTE(review): the size argument 256 is presumably in bytes (a 2048-bit
-- modulus); confirm against jose-jwt's 'generateRsaKeyPair'.
mkState :: forall user userData . UserData user userData => IO (AuthState user)
mkState = do
  (publicKey, privateKey) <- generateRsaKeyPair 256 (KeyId "Oauth2MockKey") Enc Nothing
  newTVarIO State{..}
  where
    -- Picked up by the record wildcard together with the key pair, so these
    -- local names must match the State field names exactly.
    activeCodes = Map.empty
    activeTokens = Map.empty
|
|
|