diff --git a/oauth2-mock-server.cabal b/oauth2-mock-server.cabal index 11a719a..94639e3 100644 --- a/oauth2-mock-server.cabal +++ b/oauth2-mock-server.cabal @@ -20,6 +20,7 @@ extra-source-files: library exposed-modules: AuthCode + DB Server User other-modules: @@ -45,6 +46,7 @@ library , text , time , transformers + , uuid , warp default-language: Haskell2010 @@ -74,6 +76,7 @@ executable oauth2-mock-server-exe , text , time , transformers + , uuid , warp default-language: Haskell2010 @@ -104,5 +107,6 @@ test-suite oauth2-mock-server-test , text , time , transformers + , uuid , warp default-language: Haskell2010 diff --git a/package.yaml b/package.yaml index 315b3f2..9790ba0 100644 --- a/package.yaml +++ b/package.yaml @@ -36,6 +36,7 @@ dependencies: - jose-jwt - base64 - http-api-data +- uuid ghc-options: - -Wall diff --git a/src/AuthCode.hs b/src/AuthCode.hs index 08ec313..1d23864 100644 --- a/src/AuthCode.hs +++ b/src/AuthCode.hs @@ -1,17 +1,23 @@ -{-# LANGUAGE OverloadedRecordDot, OverloadedStrings #-} +{-# LANGUAGE OverloadedRecordDot, OverloadedStrings, ScopedTypeVariables #-} module AuthCode -( State (..) +( State(..) , AuthState +, AuthRequest(..) +, JWT(..) , genUnencryptedCode , verify ) where +import User + +import Data.Aeson import Data.Map.Strict (Map) import Data.Maybe (isJust, fromMaybe) import Data.Time.Clock import Data.Text (pack, replace, Text) import Data.Text.Encoding.Base64 +import Data.UUID import qualified Data.Map.Strict as M @@ -23,42 +29,62 @@ import Control.Monad.STM import Jose.Jwk (Jwk(..)) +data JWT = JWT + { issuer :: Text + , expiration :: UTCTime + , jti :: UUID + } deriving (Show, Eq) + +instance ToJSON JWT where + toJSON (JWT i e j) = object ["iss" .= i, "exp" .= e, "jti" .= j] + +instance FromJSON JWT where + parseJSON (Object o) = JWT <$> o .: "iss" <*> o .: "exp" <*> o .: "jti" -data State = State - { activeCodes :: Map Text (String, UTCTime) -- ^ maps auth codes to (client ID, expiration time) +data AuthRequest user = AuthRequest + { client :: String + , codeExpiration :: NominalDiffTime + , user :: user + , scopes :: [Scope user] + } + + + +data State user = State + { activeCodes :: Map Text (AuthRequest user) + , activeTokens :: Map UUID (user, [Scope user]) , publicKey :: Jwk , privateKey :: Jwk - } deriving Show + } -type AuthState = TVar State +type AuthState user = TVar (State user) -genUnencryptedCode :: String +genUnencryptedCode :: AuthRequest user -> String - -> NominalDiffTime - -> AuthState + -> AuthState user -> IO (Maybe Text) -genUnencryptedCode client url expiration state = do +genUnencryptedCode req url state = do now <- getCurrentTime let - expiresAt = expiration `addUTCTime` now - simpleCode = replace "/" "%2F" . replace "=" "%3D" . encodeBase64 . pack . filter (/= ' ') $ client <> url <> show now <> show expiresAt + expiresAt = req.codeExpiration `addUTCTime` now + simpleCode = replace "/" "%2F" . replace "=" "%3D" . encodeBase64 . pack . filter (/= ' ') $ req.client <> url <> show now <> show expiresAt success <- atomically . stateTVar state $ \s -> let mEntry = M.lookup simpleCode s.activeCodes in if isJust mEntry then (False, s) - else (True, s{ activeCodes = M.insert simpleCode (client, expiresAt) s.activeCodes }) - if success then expire simpleCode expiration state >> return (Just simpleCode) else return Nothing + else (True, s{ activeCodes = M.insert simpleCode req s.activeCodes }) + if success then expire simpleCode req.codeExpiration state >> return (Just simpleCode) else return Nothing -expire :: Text -> NominalDiffTime -> AuthState -> IO () +expire :: Text -> NominalDiffTime -> AuthState user -> IO () expire code time state = void . forkIO $ do threadDelay $ fromEnum time atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes } -verify :: Text -> Maybe String -> AuthState -> IO Bool +verify :: Text -> Maybe String -> AuthState user -> IO (Maybe (user, [Scope user])) verify code mClientID state = do now <- getCurrentTime mData <- atomically $ do @@ -66,5 +92,5 @@ verify code mClientID state = do modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes } return result return $ case mData of - Just (clientID', _) -> (fromMaybe clientID' mClientID) == clientID' - _ -> False + Just (AuthRequest clientID' _ u s) -> if (fromMaybe clientID' mClientID) == clientID' then Just (u, s) else Nothing + _ -> Nothing diff --git a/src/DB.hs b/src/DB.hs new file mode 100644 index 0000000..cc1f99d --- /dev/null +++ b/src/DB.hs @@ -0,0 +1,4 @@ +module DB where + +import User + diff --git a/src/Server.hs b/src/Server.hs index f927d25..eb8dbe9 100644 --- a/src/Server.hs +++ b/src/Server.hs @@ -12,7 +12,7 @@ import User import Control.Applicative ((<|>)) import Control.Concurrent import Control.Concurrent.STM (atomically) -import Control.Concurrent.STM.TVar (newTVarIO, readTVar) +import Control.Concurrent.STM.TVar (newTVarIO, readTVar, modifyTVar) import Control.Exception (bracket) import Control.Monad (unless) import Control.Monad.IO.Class @@ -26,6 +26,7 @@ import Data.String (IsString (..)) import Data.Text hiding (elem, find, head, length, map, null, splitAt, tail, words) import Data.Text.Encoding (decodeUtf8, encodeUtf8) import Data.Time.Clock (NominalDiffTime (..), nominalDay, UTCTime(..), getCurrentTime, addUTCTime) +import Data.UUID.V4 import qualified Data.ByteString.Char8 as BS import qualified Data.Map.Strict as Map @@ -103,13 +104,13 @@ type Auth user userData = BasicAuth "login" user -- type Insert = "insert" :> Post '[JSON] User -type AuthHandler = ReaderT AuthState Handler -type AuthServer a = ServerT a AuthHandler +type AuthHandler user = ReaderT (AuthState user) Handler +type AuthServer user a = ServerT a (AuthHandler user) -toHandler :: AuthState -> AuthHandler a -> Handler a +toHandler :: forall user userData a . UserData user userData => AuthState user -> AuthHandler user a -> Handler a toHandler s h = runReaderT h s -authServer :: forall user userData . UserData user userData => AuthServer (Auth user userData) +authServer :: forall user userData . UserData user userData => AuthServer user (Auth user userData) authServer = handleAuth where handleAuth :: user @@ -118,18 +119,20 @@ authServer = handleAuth -> QResType -> QRedirect -> Maybe QState - -> AuthHandler userData + -> AuthHandler user userData handleAuth u scopes client responseType url mState = do unless (isJust $ find (\c -> ident c == pack client) trustedClients) . -- TODO fetch trusted clients from db | TODO also check if the redirect url really belongs to the client throwError $ err404 { errBody = "Not a trusted client."} - let responseType' = read @ResponseType responseType + let + responseType' = read @ResponseType responseType + scopes' = map (readScope @user @userData) $ words scopes liftIO $ print responseType' unless (responseType' == Code) $ throwError err500 { errBody = "Unsupported response type" } - mAuthCode <- asks (genUnencryptedCode client url 600) >>= liftIO + mAuthCode <- asks (genUnencryptedCode (AuthRequest client 600 u scopes') url) >>= liftIO liftIO $ print mAuthCode - -- liftIO . putStrLn $ "user: " ++ show u ++ " | scopes: " ++ show (map (showScope @user @userData) scopes') + liftIO . putStrLn $ "user: " ++ show u ++ " | scopes: " ++ show (map (showScope @user @userData) scopes') redirect $ addParams url mAuthCode mState - redirect :: Maybe ByteString -> AuthHandler userData + redirect :: Maybe ByteString -> AuthHandler user userData redirect (Just url) = liftIO (print url) >> throwError err303 { errHeaders = [("Location", url)]} redirect Nothing = throwError err500 { errBody = "Could not generate authorisation code."} addParams :: String -> Maybe Text -> Maybe Text -> Maybe ByteString @@ -167,18 +170,6 @@ instance FromForm ClientData where <*> parseMaybe "redirect_uri" f - -data JWT = JWT - { issuer :: Text - , expiration :: UTCTime - } deriving (Show, Eq) - -instance ToJSON JWT where - toJSON (JWT i e) = object ["iss" .= i, "exp" .= e] - -instance FromJSON JWT where - parseJSON (Object o) = JWT <$> o .: "iss" <*> o .: "exp" - data JWTWrapper = JWTW { token :: String , expiresIn :: NominalDiffTime @@ -202,32 +193,37 @@ type Token = "token" :> Post '[JSON] JWTWrapper -tokenEndpoint :: AuthServer Token +tokenEndpoint :: forall user userData . UserData user userData => AuthServer user Token tokenEndpoint = provideToken where - provideToken :: ClientData -> AuthHandler JWTWrapper + provideToken :: ClientData -> AuthHandler user JWTWrapper provideToken client = do unless (isNothing (clientID client >> clientSecret client) || Client (pack . fromJust $ clientID client) (pack . fromJust $ clientSecret client) `elem` trustedClients) . throwError $ err500 { errBody = "Invalid client" } - valid <- asks (verify (pack $ authCode client) (clientID client)) >>= liftIO -- TODO verify redirect url here - unless valid . throwError $ err500 { errBody = "Invalid authorisation code" } + mUser <- asks (verify (pack $ authCode client) (clientID client)) >>= liftIO -- TODO verify redirect url here + unless (isJust mUser) . throwError $ err500 { errBody = "Invalid authorisation code" } -- return JWT {token = "", tokenType = "JWT", expiration = 0.25 * nominalDay} - token <- asks mkToken >>= liftIO + let (user, scopes) = fromJust mUser + token <- asks (mkToken @user @userData user scopes) >>= liftIO liftIO . putStrLn $ "token: " ++ show token return token -mkToken :: AuthState -> IO JWTWrapper -mkToken state = do +mkToken :: forall user userData . UserData user userData + => user -> [Scope user] -> AuthState user -> IO JWTWrapper +mkToken u scopes state = do pubKey <- atomically $ readTVar state >>= return . publicKey now <- getCurrentTime + uuid <- nextRandom let - lifetime = nominalDay / 4 -- TODO make configurable - jwt = JWT "Oauth2MockServer" (lifetime `addUTCTime` now) + lifetime = nominalDay / 24 -- TODO make configurable + jwt = JWT "Oauth2MockServer" (lifetime `addUTCTime` now) uuid encoded <- jwkEncode RSA_OAEP_256 A128GCM pubKey (Nested . Jwt . toStrict $ encode jwt) case encoded of - Right (Jwt token) -> return $ JWTW (BS.unpack token) lifetime + Right (Jwt token) -> do + atomically . modifyTVar state $ \s -> s { activeTokens = Map.insert uuid (u, scopes) (activeTokens s) } + return $ JWTW (BS.unpack token) lifetime Left e -> error $ show e @@ -250,20 +246,21 @@ type UserList userData = Users :> Get '[JSON] [userData] -- TODO support query params -userEndpoint :: forall user userData . UserData user userData => AuthServer (Me userData) +userEndpoint :: forall user userData . UserData user userData => AuthServer user (Me userData) userEndpoint = handleUserData where - handleUserData :: Text -> AuthHandler userData + handleUserData :: Text -> AuthHandler user userData handleUserData jwtw = do let mToken = stripPrefix "Bearer " jwtw unless (isJust mToken) . throwError $ err500 { errBody = "Invalid token format"} - token <- asks (decodeToken (fromJust mToken)) >>= liftIO + token <- asks (decodeToken @user @userData (fromJust mToken)) >>= liftIO liftIO $ putStrLn "decoded token:" >> print token case token of Left e -> throwError $ err500 { errBody = fromString $ show e } Right (Jwe (header, body)) -> do let jwt = fromJust . decode @JWT $ fromStrict body -- TODO check if token grants access, then read logged in user from cookie + liftIO $ print jwt return mempty -- let -- scopes' = map (readScope @user @userData) $ words scopes @@ -271,15 +268,15 @@ userEndpoint = handleUserData -- liftIO . putStrLn $ "user: " ++ show u ++ " | scopes: " ++ show (map (showScope @user @userData) scopes') -- return uData -decodeToken :: Text -> AuthState -> IO (Either JwtError JwtContent) +decodeToken :: forall user userData . UserData user userData => Text -> AuthState user -> IO (Either JwtError JwtContent) decodeToken token state = do prKey <- atomically $ readTVar state >>= return . privateKey jwkDecode prKey $ encodeUtf8 token -userListEndpoint :: forall user userData . UserData user userData => AuthServer (UserList userData) +userListEndpoint :: forall user userData . UserData user userData => AuthServer user (UserList userData) userListEndpoint = handleUserData where - handleUserData :: Text -> AuthHandler [userData] + handleUserData :: Text -> AuthHandler user [userData] handleUserData jwtw = do undefined @@ -293,13 +290,13 @@ type Routing user userData = Auth user userData :<|> Me userData :<|> UserList userData -routing :: forall user userData . UserData user userData => AuthServer (Routing user userData) +routing :: forall user userData . UserData user userData => AuthServer user (Routing user userData) routing = authServer @user @userData - :<|> tokenEndpoint + :<|> tokenEndpoint @user @userData :<|> userEndpoint @user @userData :<|> userListEndpoint @user @userData -exampleAuthServer :: AuthServer (Routing User (Map.Map Text Text)) +exampleAuthServer :: AuthServer User (Routing User (Map.Map Text Text)) exampleAuthServer = routing @@ -309,8 +306,8 @@ authAPI = Proxy -- insecureOAuthMock :: Application -- insecureOAuthMock = authAPI `serve` exampleAuthServer -insecureOAuthMock' :: [User] -> AuthState -> Application -insecureOAuthMock' testUsers s = serveWithContext authAPI c $ hoistServerWithContext authAPI p (toHandler s) exampleAuthServer +insecureOAuthMock' :: [User] -> AuthState User -> Application +insecureOAuthMock' testUsers s = serveWithContext authAPI c $ hoistServerWithContext authAPI p (toHandler @User @(Map.Map Text Text) s) exampleAuthServer where c = authenticate testUsers :. EmptyContext p = Proxy :: Proxy '[BasicAuthCheck User] @@ -328,7 +325,7 @@ authenticate users = BasicAuthCheck $ \authData -> do runMockServer :: Int -> IO () runMockServer port = do - state <- mkState + state <- mkState @User @(Map.Map Text Text) run port $ insecureOAuthMock' testUsers state -- runMockServer' :: Int -> IO () @@ -339,9 +336,11 @@ runMockServer port = do -- runClientM (frontend $ BasicAuthData "foo@bar.com" "0000") (mkClientEnv mgr (BaseUrl Http "localhost" port "")) -- >>= print -mkState :: IO AuthState +mkState :: forall user userData . UserData user userData => IO (AuthState user) mkState = do (publicKey, privateKey) <- generateRsaKeyPair 256 (KeyId "Oauth2MockKey") Enc Nothing - let activeCodes = Map.empty + let + activeCodes = Map.empty + activeTokens = Map.empty newTVarIO State{..}