diff --git a/src/AuthCode.hs b/src/AuthCode.hs index f836724..192ccdd 100644 --- a/src/AuthCode.hs +++ b/src/AuthCode.hs @@ -32,9 +32,9 @@ import Data.Map.Strict (Map) import Data.Maybe (isJust, fromMaybe, fromJust, catMaybes) import Data.Time.Calendar import Data.Time.Clock -import Data.Text (pack, replace, Text, stripPrefix) +import Data.Text (pack, Text, stripPrefix) import Data.Text.Encoding (decodeUtf8, encodeUtf8) -import Data.Text.Encoding.Base64 +import Data.Text.Encoding.Base64.URL import Data.UUID hiding (null) import Data.UUID.V4 @@ -152,7 +152,7 @@ genUnencryptedCode req url state = do now <- getCurrentTime let expiresAt = req.codeExpiration `addUTCTime` now - simpleCode = replace "/" "%2F" . replace "=" "%3D" . encodeBase64 . pack . filter (/= ' ') $ req.client <> url <> show now <> show expiresAt + simpleCode = encodeBase64Unpadded . pack $ req.client <> url <> show now <> show expiresAt success <- atomically . stateTVar state $ \s -> let mEntry = M.lookup simpleCode s.activeCodes in @@ -237,17 +237,19 @@ renewToken :: forall user userData . UserData user userData -> [Scope' user] -> Maybe Text -- ^ client_id -> AuthState user - -> IO (Maybe JWTWrapper) -- TODO more descriptive failures + -> IO (Either Text JWTWrapper) -- TODO more descriptive failures renewToken t scopes clientID state = decodeToken t state >>= \case Right (Jwe (header, body)) -> do let jwt = fromJust . decode @JWT $ fromStrict body now <- getCurrentTime - if now >= expiration jwt then return Nothing else do + if now >= expiration jwt then return (Left "token expired") else do mUser <- atomically . stateTVar state $ \s -> let (key, tokens) = M.updateLookupWithKey (\_ _ -> Nothing) (jti jwt) s.activeTokens in (key, s { activeTokens = tokens }) case mUser of - Just (u, scopes', nonce) -> bool (pure Nothing) (Just <$> mkToken @user @userData (u, scopes, nonce) clientID state) (null $ scopes \\ scopes') - Nothing -> return Nothing - Left _ -> return Nothing + Just (u, scopes', nonce) -> bool (pure $ Left "must not request new scopes") + (Right <$> mkToken @user @userData (u, scopes, nonce) clientID state) + (null $ scopes \\ scopes') + Nothing -> return $ Left "no user associated with this token" + Left _ -> return $ Left "could not decode token" diff --git a/src/Server.hs b/src/Server.hs index 7c9c8de..cc27314 100644 --- a/src/Server.hs +++ b/src/Server.hs @@ -222,7 +222,7 @@ handleCreds creds scopes client url mState mNonce = do addParams url (Just code) mState = let urlParts = splitOn "?" url (pre, post) = if length urlParts == 2 then (urlParts !! 0, urlParts !! 1) else (head urlParts, "") - rState = case mState of {Just s -> "&state=" <> (replace "/" "%2F" $ replace "=" "%3D" s); Nothing -> ""} + rState = case mState of {Just s -> "&" <> (decodeUtf8 . toStrict $ urlEncodeParams [("state", s)]) ; Nothing -> ""} post' = if not (T.null post) then "&" <> T.tail post else post in Just $ pre <> "?code=" <> code <> post' <> rState @@ -292,10 +292,10 @@ tokenEndpoint = provideToken liftIO . putStrLn $ "\aSCOPES: " ++ show scopes' unless (grantType client == "refresh_token") . throwError $ err500 { errBody = "Invalid grant_type" } liftIO $ putStrLn "... checking refresh token" - mToken <- asks (renewToken @user @userData jwtw (fromMaybe [] scopes') cid) >>= liftIO - case mToken of - Just token -> liftIO (putStrLn $ "refreshed token: " ++ show token) >> return token - Nothing -> throwError $ err500 { errBody = "Invalid refresh token" } + eToken <- asks (renewToken @user @userData jwtw (fromMaybe [] scopes') cid) >>= liftIO + case eToken of + Right token -> liftIO (putStrLn $ "refreshed token: " ++ show token) >> return token + Left err -> throwError $ err500 { errBody = fromStrict $ encodeUtf8 err } ----------------------