diff --git a/package.yaml b/package.yaml index 3d42327..d409dcd 100644 --- a/package.yaml +++ b/package.yaml @@ -32,11 +32,13 @@ library: - http-types >=0.8 && <0.13 - memory - microlens + - mtl - safe-exceptions - text >=0.7 && <2.0 - uri-bytestring - yesod-auth >=1.6.0 && <1.7 - yesod-core >=1.6.0 && <1.7 + - unliftio executables: yesod-auth-oauth2-example: diff --git a/src/UnliftIO/Except.hs b/src/UnliftIO/Except.hs new file mode 100644 index 0000000..728951e --- /dev/null +++ b/src/UnliftIO/Except.hs @@ -0,0 +1,12 @@ +{-# OPTIONS_GHC -Wno-orphans #-} + +module UnliftIO.Except + () where + +import Control.Monad.Except +import UnliftIO + +instance (MonadUnliftIO m, Exception e) => MonadUnliftIO (ExceptT e m) where + withRunInIO exceptToIO = ExceptT $ try $ do + withRunInIO $ \runInIO -> + exceptToIO (runInIO . (either throwIO pure <=< runExceptT)) diff --git a/src/Yesod/Auth/OAuth2/Dispatch.hs b/src/Yesod/Auth/OAuth2/Dispatch.hs index 746d35d..162e9ff 100644 --- a/src/Yesod/Auth/OAuth2/Dispatch.hs +++ b/src/Yesod/Auth/OAuth2/Dispatch.hs @@ -1,36 +1,30 @@ {-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TemplateHaskell #-} -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} + module Yesod.Auth.OAuth2.Dispatch ( FetchToken , fetchAccessToken , fetchAccessToken2 , FetchCreds , dispatchAuthRequest - ) -where + ) where -import Control.Exception.Safe -import Control.Monad (unless, (<=<)) -import Crypto.Random (getRandomBytes) -import Data.ByteArray.Encoding (Base(Base64), convertToBase) -import Data.ByteString (ByteString) +import Control.Monad.Except import Data.Text (Text) import qualified Data.Text as T -import Data.Text.Encoding (decodeUtf8, encodeUtf8) +import Data.Text.Encoding (encodeUtf8) import Network.HTTP.Conduit (Manager) import Network.OAuth.OAuth2 import Network.OAuth.OAuth2.TokenRequest (Errors) +import UnliftIO.Exception import URI.ByteString.Extension import Yesod.Auth hiding (ServerError) +import Yesod.Auth.OAuth2.DispatchError import Yesod.Auth.OAuth2.ErrorResponse -import Yesod.Auth.OAuth2.Exception +import Yesod.Auth.OAuth2.Random import Yesod.Core hiding (ErrorResponse) -- | How to fetch an @'OAuth2Token'@ @@ -53,9 +47,9 @@ dispatchAuthRequest -> [Text] -- ^ Path pieces -> AuthHandler m TypedContent dispatchAuthRequest name oauth2 _ _ "GET" ["forward"] = - dispatchForward name oauth2 + handleDispatchError $ dispatchForward name oauth2 dispatchAuthRequest name oauth2 getToken getCreds "GET" ["callback"] = - dispatchCallback name oauth2 getToken getCreds + handleDispatchError $ dispatchCallback name oauth2 getToken getCreds dispatchAuthRequest _ _ _ _ _ _ = notFound -- | Handle @GET \/forward@ @@ -63,7 +57,11 @@ dispatchAuthRequest _ _ _ _ _ _ = notFound -- 1. Set a random CSRF token in our session -- 2. Redirect to the Provider's authorization URL -- -dispatchForward :: Text -> OAuth2 -> AuthHandler m TypedContent +dispatchForward + :: (MonadError DispatchError m, MonadAuthHandler site m) + => Text + -> OAuth2 + -> m TypedContent dispatchForward name oauth2 = do csrf <- setSessionCSRF $ tokenSessionKey name oauth2' <- withCallbackAndState name oauth2 csrf @@ -76,75 +74,47 @@ dispatchForward name oauth2 = do -- 3. Use the AccessToken to construct a @'Creds'@ value for the Provider -- dispatchCallback - :: Text + :: (MonadError DispatchError m, MonadAuthHandler site m) + => Text -> OAuth2 -> FetchToken - -> FetchCreds m - -> AuthHandler m TypedContent + -> FetchCreds site + -> m TypedContent dispatchCallback name oauth2 getToken getCreds = do csrf <- verifySessionCSRF $ tokenSessionKey name - onErrorResponse $ oauth2HandshakeError name + onErrorResponse $ throwError . OAuth2HandshakeError code <- requireGetParam "code" manager <- authHttpManager oauth2' <- withCallbackAndState name oauth2 csrf - token <- errLeft $ getToken manager oauth2' $ ExchangeToken code - creds <- errLeft $ tryFetchCreds $ getCreds manager token + token <- + errLeft OAuth2ResultError $ getToken manager oauth2' $ ExchangeToken + code + creds <- errLeft id $ tryFetchCreds $ getCreds manager token setCredsRedirect creds where - errLeft :: Show e => IO (Either e a) -> AuthHandler m a - errLeft = either (unexpectedError name) pure <=< liftIO + errLeft + :: (MonadIO m, MonadError e m) => (e' -> e) -> IO (Either e' a) -> m a + errLeft f = either (throwError . f) pure <=< liftIO --- | Handle an OAuth2 @'ErrorResponse'@ --- --- These are things coming from the OAuth2 provider such an Invalid Grant or --- Invalid Scope and /may/ be user-actionable. We've coded them to have an --- @'erUserMessage'@ that we are comfortable displaying to the user as part of --- the redirect, just in case. --- -oauth2HandshakeError :: Text -> ErrorResponse -> AuthHandler m a -oauth2HandshakeError name err = do - $(logError) $ "Handshake failure in " <> name <> " plugin: " <> tshow err - redirectMessage $ "OAuth2 handshake failure: " <> erUserMessage err - --- | Handle an unexpected error --- --- This would be some unexpected exception while processing the callback. --- Therefore, the user should see an opaque message and the details go only to --- the server logs. --- -unexpectedError :: Show e => Text -> e -> AuthHandler m a -unexpectedError name err = do - $(logError) $ "Error in " <> name <> " OAuth2 plugin: " <> tshow err - redirectMessage "Unexpected error logging in with OAuth2" - -redirectMessage :: Text -> AuthHandler m a -redirectMessage msg = do - toParent <- getRouteToParent - setMessage $ toHtml msg - redirect $ toParent LoginR - -tryFetchCreds :: IO a -> IO (Either SomeException a) +tryFetchCreds :: IO a -> IO (Either DispatchError a) tryFetchCreds f = (Right <$> f) - `catch` (\(ex :: IOException) -> pure $ Left $ toException ex) - `catch` (\(ex :: YesodOAuth2Exception) -> pure $ Left $ toException ex) + `catch` (pure . Left . FetchCredsIOException) + `catch` (pure . Left . FetchCredsYesodOAuth2Exception) -withCallbackAndState :: Text -> OAuth2 -> Text -> AuthHandler m OAuth2 +withCallbackAndState + :: (MonadError DispatchError m, MonadAuthHandler site m) + => Text + -> OAuth2 + -> Text + -> m OAuth2 withCallbackAndState name oauth2 csrf = do let url = PluginR name ["callback"] render <- getParentUrlRender let callbackText = render url - callback <- - maybe - (liftIO - $ throwString - $ "Invalid callback URI: " - <> T.unpack callbackText - <> ". Not using an absolute Approot?" - ) - pure - $ fromText callbackText + callback <- maybe (throwError $ InvalidCallbackUri callbackText) pure + $ fromText callbackText pure oauth2 { oauthCallback = Just callback @@ -169,40 +139,28 @@ setSessionCSRF :: MonadHandler m => Text -> m Text setSessionCSRF sessionKey = do csrfToken <- liftIO randomToken csrfToken <$ setSession sessionKey csrfToken - where - randomToken = - T.filter (/= '+') - . decodeUtf8 - . convertToBase @ByteString Base64 - <$> getRandomBytes 64 + where randomToken = T.filter (/= '+') <$> randomText 64 -- | Verify the callback provided the same CSRF token as in our session -verifySessionCSRF :: MonadHandler m => Text -> m Text +verifySessionCSRF + :: (MonadError DispatchError m, MonadHandler m) => Text -> m Text verifySessionCSRF sessionKey = do token <- requireGetParam "state" sessionToken <- lookupSession sessionKey deleteSession sessionKey - unless (sessionToken == Just token) $ do - $(logError) - $ "state token does not match. " - <> "Param: " - <> tshow token - <> "State: " - <> tshow sessionToken - permissionDenied "Invalid OAuth2 state token" + unless (sessionToken == Just token) $ throwError $ InvalidStateToken + sessionToken + token - return token + pure token -requireGetParam :: MonadHandler m => Text -> m Text +requireGetParam + :: (MonadError DispatchError m, MonadHandler m) => Text -> m Text requireGetParam key = do m <- lookupGetParam key - maybe errInvalidArgs return m - where - errInvalidArgs = invalidArgs ["The '" <> key <> "' parameter is required"] + maybe err return m + where err = throwError $ MissingParameter key tokenSessionKey :: Text -> Text tokenSessionKey name = "_yesod_oauth2_" <> name - -tshow :: Show a => a -> Text -tshow = T.pack . show diff --git a/src/Yesod/Auth/OAuth2/DispatchError.hs b/src/Yesod/Auth/OAuth2/DispatchError.hs new file mode 100644 index 0000000..b778968 --- /dev/null +++ b/src/Yesod/Auth/OAuth2/DispatchError.hs @@ -0,0 +1,78 @@ +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} + +module Yesod.Auth.OAuth2.DispatchError + ( DispatchError(..) + , handleDispatchError + ) where + +import Control.Monad.Except +import Data.Text (Text, pack) +import Network.OAuth.OAuth2 +import Network.OAuth.OAuth2.TokenRequest (Errors) +import UnliftIO.Except () +import UnliftIO.Exception +import Yesod.Auth hiding (ServerError) +import Yesod.Auth.OAuth2.ErrorResponse +import Yesod.Auth.OAuth2.Exception +import Yesod.Auth.OAuth2.Random +import Yesod.Core hiding (ErrorResponse) + +data DispatchError + = MissingParameter Text + | InvalidStateToken (Maybe Text) Text + | InvalidCallbackUri Text + | OAuth2HandshakeError ErrorResponse + | OAuth2ResultError (OAuth2Error Errors) + | FetchCredsIOException IOException + | FetchCredsYesodOAuth2Exception YesodOAuth2Exception + deriving stock Show + deriving anyclass Exception + +-- | User-friendly message for any given 'DispatchError' +-- +-- Most of these are opaque to the user. The exception details are present for +-- the server logs. +-- +dispatchErrorMessage :: DispatchError -> Text +dispatchErrorMessage = \case + MissingParameter name -> + "Parameter '" <> name <> "' is required, but not present in the URL" + InvalidStateToken{} -> "State token is invalid, please try again" + InvalidCallbackUri{} + -> "Callback URI was not valid, this server may be misconfigured (no approot)" + OAuth2HandshakeError er -> "OAuth2 handshake failure: " <> erUserMessage er + OAuth2ResultError{} -> "Login failed, please try again" + FetchCredsIOException{} -> "Login failed, please try again" + FetchCredsYesodOAuth2Exception{} -> "Login failed, please try again" + +handleDispatchError + :: MonadAuthHandler site m + => ExceptT DispatchError m TypedContent + -> m TypedContent +handleDispatchError f = do + result <- runExceptT f + either onDispatchError pure result + +onDispatchError :: MonadAuthHandler site m => DispatchError -> m TypedContent +onDispatchError err = do + errorId <- liftIO $ randomText 16 + let suffix = " [errorId=" <> errorId <> "]" + $(logError) $ pack (displayException err) <> suffix + + let message = dispatchErrorMessage err <> suffix + messageValue = + object ["error" .= object ["id" .= errorId, "message" .= message]] + + loginR <- ($ LoginR) <$> getRouteToParent + + selectRep $ do + provideRep @_ @Html $ onErrorHtml loginR message + provideRep @_ @Value $ pure messageValue diff --git a/src/Yesod/Auth/OAuth2/Random.hs b/src/Yesod/Auth/OAuth2/Random.hs new file mode 100644 index 0000000..52b6072 --- /dev/null +++ b/src/Yesod/Auth/OAuth2/Random.hs @@ -0,0 +1,19 @@ +{-# LANGUAGE TypeApplications #-} + +module Yesod.Auth.OAuth2.Random + ( randomText + ) where + +import Crypto.Random (MonadRandom, getRandomBytes) +import Data.ByteArray.Encoding (Base(Base64), convertToBase) +import Data.ByteString (ByteString) +import Data.Text (Text) +import Data.Text.Encoding (decodeUtf8) + +randomText + :: MonadRandom m + => Int + -- ^ Size in Bytes (note necessarily characters) + -> m Text +randomText size = + decodeUtf8 . convertToBase @ByteString Base64 <$> getRandomBytes size