fixed url encoding

This commit is contained in:
David Mosbach 2024-03-25 03:51:04 +00:00
parent 7b995e6cff
commit 51a9a1acc1
2 changed files with 15 additions and 13 deletions

View File

@ -32,9 +32,9 @@ import Data.Map.Strict (Map)
import Data.Maybe (isJust, fromMaybe, fromJust, catMaybes) import Data.Maybe (isJust, fromMaybe, fromJust, catMaybes)
import Data.Time.Calendar import Data.Time.Calendar
import Data.Time.Clock import Data.Time.Clock
import Data.Text (pack, replace, Text, stripPrefix) import Data.Text (pack, Text, stripPrefix)
import Data.Text.Encoding (decodeUtf8, encodeUtf8) import Data.Text.Encoding (decodeUtf8, encodeUtf8)
import Data.Text.Encoding.Base64 import Data.Text.Encoding.Base64.URL
import Data.UUID hiding (null) import Data.UUID hiding (null)
import Data.UUID.V4 import Data.UUID.V4
@ -152,7 +152,7 @@ genUnencryptedCode req url state = do
now <- getCurrentTime now <- getCurrentTime
let let
expiresAt = req.codeExpiration `addUTCTime` now expiresAt = req.codeExpiration `addUTCTime` now
simpleCode = replace "/" "%2F" . replace "=" "%3D" . encodeBase64 . pack . filter (/= ' ') $ req.client <> url <> show now <> show expiresAt simpleCode = encodeBase64Unpadded . pack $ req.client <> url <> show now <> show expiresAt
success <- atomically . stateTVar state $ \s -> success <- atomically . stateTVar state $ \s ->
let mEntry = M.lookup simpleCode s.activeCodes let mEntry = M.lookup simpleCode s.activeCodes
in in
@ -237,17 +237,19 @@ renewToken :: forall user userData . UserData user userData
-> [Scope' user] -> [Scope' user]
-> Maybe Text -- ^ client_id -> Maybe Text -- ^ client_id
-> AuthState user -> AuthState user
-> IO (Maybe JWTWrapper) -- TODO more descriptive failures -> IO (Either Text JWTWrapper) -- TODO more descriptive failures
renewToken t scopes clientID state = decodeToken t state >>= \case renewToken t scopes clientID state = decodeToken t state >>= \case
Right (Jwe (header, body)) -> do Right (Jwe (header, body)) -> do
let jwt = fromJust . decode @JWT $ fromStrict body let jwt = fromJust . decode @JWT $ fromStrict body
now <- getCurrentTime now <- getCurrentTime
if now >= expiration jwt then return Nothing else do if now >= expiration jwt then return (Left "token expired") else do
mUser <- atomically . stateTVar state $ \s -> mUser <- atomically . stateTVar state $ \s ->
let (key, tokens) = M.updateLookupWithKey (\_ _ -> Nothing) (jti jwt) s.activeTokens let (key, tokens) = M.updateLookupWithKey (\_ _ -> Nothing) (jti jwt) s.activeTokens
in (key, s { activeTokens = tokens }) in (key, s { activeTokens = tokens })
case mUser of case mUser of
Just (u, scopes', nonce) -> bool (pure Nothing) (Just <$> mkToken @user @userData (u, scopes, nonce) clientID state) (null $ scopes \\ scopes') Just (u, scopes', nonce) -> bool (pure $ Left "must not request new scopes")
Nothing -> return Nothing (Right <$> mkToken @user @userData (u, scopes, nonce) clientID state)
Left _ -> return Nothing (null $ scopes \\ scopes')
Nothing -> return $ Left "no user associated with this token"
Left _ -> return $ Left "could not decode token"

View File

@ -222,7 +222,7 @@ handleCreds creds scopes client url mState mNonce = do
addParams url (Just code) mState = addParams url (Just code) mState =
let urlParts = splitOn "?" url let urlParts = splitOn "?" url
(pre, post) = if length urlParts == 2 then (urlParts !! 0, urlParts !! 1) else (head urlParts, "") (pre, post) = if length urlParts == 2 then (urlParts !! 0, urlParts !! 1) else (head urlParts, "")
rState = case mState of {Just s -> "&state=" <> (replace "/" "%2F" $ replace "=" "%3D" s); Nothing -> ""} rState = case mState of {Just s -> "&" <> (decodeUtf8 . toStrict $ urlEncodeParams [("state", s)]) ; Nothing -> ""}
post' = if not (T.null post) then "&" <> T.tail post else post post' = if not (T.null post) then "&" <> T.tail post else post
in Just $ pre <> "?code=" <> code <> post' <> rState in Just $ pre <> "?code=" <> code <> post' <> rState
@ -292,10 +292,10 @@ tokenEndpoint = provideToken
liftIO . putStrLn $ "\aSCOPES: " ++ show scopes' liftIO . putStrLn $ "\aSCOPES: " ++ show scopes'
unless (grantType client == "refresh_token") . throwError $ err500 { errBody = "Invalid grant_type" } unless (grantType client == "refresh_token") . throwError $ err500 { errBody = "Invalid grant_type" }
liftIO $ putStrLn "... checking refresh token" liftIO $ putStrLn "... checking refresh token"
mToken <- asks (renewToken @user @userData jwtw (fromMaybe [] scopes') cid) >>= liftIO eToken <- asks (renewToken @user @userData jwtw (fromMaybe [] scopes') cid) >>= liftIO
case mToken of case eToken of
Just token -> liftIO (putStrLn $ "refreshed token: " ++ show token) >> return token Right token -> liftIO (putStrLn $ "refreshed token: " ++ show token) >> return token
Nothing -> throwError $ err500 { errBody = "Invalid refresh token" } Left err -> throwError $ err500 { errBody = fromStrict $ encodeUtf8 err }
---------------------- ----------------------