yesod-auth-oauth2/src/Yesod/Auth/OAuth2/Dispatch.hs
2023-10-31 11:38:51 -05:00

158 lines
5.0 KiB
Haskell

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Yesod.Auth.OAuth2.Dispatch
( FetchToken
, fetchAccessToken
, fetchAccessToken2
, FetchCreds
, dispatchAuthRequest
) where
import Control.Monad (unless)
import Control.Monad.Except (MonadError (..))
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Network.HTTP.Conduit (Manager)
import Network.OAuth.OAuth2.Compat
import URI.ByteString.Extension
import UnliftIO.Exception
import Yesod.Auth hiding (ServerError)
import Yesod.Auth.OAuth2.DispatchError
import Yesod.Auth.OAuth2.ErrorResponse
import Yesod.Auth.OAuth2.Random
import Yesod.Core hiding (ErrorResponse)
-- | How to fetch an @'OAuth2Token'@
--
-- This will be 'fetchAccessToken' or 'fetchAccessToken2'
type FetchToken =
Manager -> OAuth2 -> ExchangeToken -> IO (OAuth2Result Errors OAuth2Token)
-- | How to take an @'OAuth2Token'@ and retrieve user credentials
type FetchCreds m = Manager -> OAuth2Token -> IO (Creds m)
-- | Dispatch the various OAuth2 handshake routes
dispatchAuthRequest
:: Text
-- ^ Name
-> OAuth2
-- ^ Service details
-> FetchToken
-- ^ How to get a token
-> FetchCreds m
-- ^ How to get credentials
-> Text
-- ^ Method
-> [Text]
-- ^ Path pieces
-> AuthHandler m TypedContent
dispatchAuthRequest name oauth2 _ _ "GET" ["forward"] =
handleDispatchError $ dispatchForward name oauth2
dispatchAuthRequest name oauth2 getToken getCreds "GET" ["callback"] =
handleDispatchError $ dispatchCallback name oauth2 getToken getCreds
dispatchAuthRequest _ _ _ _ _ _ = notFound
-- | Handle @GET \/forward@
--
-- 1. Set a random CSRF token in our session
-- 2. Redirect to the Provider's authorization URL
dispatchForward
:: (MonadError DispatchError m, MonadAuthHandler site m)
=> Text
-> OAuth2
-> m TypedContent
dispatchForward name oauth2 = do
csrf <- setSessionCSRF $ tokenSessionKey name
oauth2' <- withCallbackAndState name oauth2 csrf
redirect $ toText $ authorizationUrl oauth2'
-- | Handle @GET \/callback@
--
-- 1. Verify the URL's CSRF token matches our session
-- 2. Use the code parameter to fetch an AccessToken for the Provider
-- 3. Use the AccessToken to construct a @'Creds'@ value for the Provider
dispatchCallback
:: (MonadError DispatchError m, MonadAuthHandler site m)
=> Text
-> OAuth2
-> FetchToken
-> FetchCreds site
-> m TypedContent
dispatchCallback name oauth2 getToken getCreds = do
onErrorResponse $ throwError . OAuth2HandshakeError
csrf <- verifySessionCSRF $ tokenSessionKey name
code <- requireGetParam "code"
manager <- authHttpManager
oauth2' <- withCallbackAndState name oauth2 csrf
token <-
either (throwError . OAuth2ResultError) pure
=<< liftIO (getToken manager oauth2' $ ExchangeToken code)
creds <-
liftIO (getCreds manager token)
`catch` (throwError . FetchCredsIOException)
`catch` (throwError . FetchCredsYesodOAuth2Exception)
setCredsRedirect creds
withCallbackAndState
:: (MonadError DispatchError m, MonadAuthHandler site m)
=> Text
-> OAuth2
-> Text
-> m OAuth2
withCallbackAndState name oauth2 csrf = do
callback <-
case oauth2RedirectUri oauth2 of
Just uri -> pure uri
Nothing -> do
uri <- ($ PluginR name ["callback"]) <$> getParentUrlRender
maybe (throwError $ InvalidCallbackUri uri) pure $ fromText uri
pure
oauth2
{ oauth2RedirectUri = Just callback
, oauth2AuthorizeEndpoint =
oauth2AuthorizeEndpoint oauth2 `withQuery` [("state", encodeUtf8 csrf)]
}
getParentUrlRender :: MonadHandler m => m (Route (SubHandlerSite m) -> Text)
getParentUrlRender = (.) <$> getUrlRender <*> getRouteToParent
-- | Set a random, ~64-byte value in the session
--
-- Some (but not all) providers decode a @+@ in the state token as a space when
-- sending it back to us. We don't expect this and fail. And if we did code for
-- it, we'd then fail on the providers that /don't/ do that.
--
-- Therefore, we just exclude @+@ in our tokens, which means this function may
-- return slightly fewer than 64 bytes.
setSessionCSRF :: MonadHandler m => Text -> m Text
setSessionCSRF sessionKey = do
csrfToken <- liftIO randomToken
csrfToken <$ setSession sessionKey csrfToken
where
randomToken = T.filter (/= '+') <$> randomText 64
-- | Verify the callback provided the same CSRF token as in our session
verifySessionCSRF
:: (MonadError DispatchError m, MonadHandler m) => Text -> m Text
verifySessionCSRF sessionKey = do
token <- requireGetParam "state"
sessionToken <- lookupSession sessionKey
deleteSession sessionKey
token
<$ unless
(sessionToken == Just token)
(throwError $ InvalidStateToken sessionToken token)
requireGetParam
:: (MonadError DispatchError m, MonadHandler m) => Text -> m Text
requireGetParam key =
maybe (throwError $ MissingParameter key) pure =<< lookupGetParam key
tokenSessionKey :: Text -> Text
tokenSessionKey name = "_yesod_oauth2_" <> name