base64 encoding for auth code
This commit is contained in:
parent
83dacacf56
commit
9e61d356e2
@ -32,6 +32,7 @@ library
|
||||
build-depends:
|
||||
aeson
|
||||
, base >=4.7 && <5
|
||||
, base64
|
||||
, bytestring
|
||||
, containers
|
||||
, http-client
|
||||
@ -58,6 +59,7 @@ executable oauth2-mock-server-exe
|
||||
build-depends:
|
||||
aeson
|
||||
, base >=4.7 && <5
|
||||
, base64
|
||||
, bytestring
|
||||
, containers
|
||||
, http-client
|
||||
@ -86,6 +88,7 @@ test-suite oauth2-mock-server-test
|
||||
build-depends:
|
||||
aeson
|
||||
, base >=4.7 && <5
|
||||
, base64
|
||||
, bytestring
|
||||
, containers
|
||||
, http-client
|
||||
|
||||
@ -34,6 +34,7 @@ dependencies:
|
||||
- transformers
|
||||
- bytestring
|
||||
- jose-jwt
|
||||
- base64
|
||||
|
||||
ghc-options:
|
||||
- -Wall
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
{-# LANGUAGE OverloadedRecordDot #-}
|
||||
{-# LANGUAGE OverloadedRecordDot, OverloadedStrings #-}
|
||||
|
||||
module AuthCode
|
||||
( State (..)
|
||||
@ -10,6 +10,8 @@ module AuthCode
|
||||
import Data.Map.Strict (Map)
|
||||
import Data.Maybe (isJust)
|
||||
import Data.Time.Clock
|
||||
import Data.Text (pack, replace, Text)
|
||||
import Data.Text.Encoding.Base64
|
||||
|
||||
import qualified Data.Map.Strict as M
|
||||
|
||||
@ -24,7 +26,7 @@ import Jose.Jwk (Jwk(..))
|
||||
|
||||
|
||||
data State = State
|
||||
{ activeCodes :: Map String (String, UTCTime) -- ^ maps auth codes to (client ID, expiration time)
|
||||
{ activeCodes :: Map Text (String, UTCTime) -- ^ maps auth codes to (client ID, expiration time)
|
||||
, publicKey :: Jwk
|
||||
, privateKey :: Jwk
|
||||
} deriving Show
|
||||
@ -35,12 +37,12 @@ genUnencryptedCode :: String
|
||||
-> String
|
||||
-> NominalDiffTime
|
||||
-> AuthState
|
||||
-> IO (Maybe String)
|
||||
-> IO (Maybe Text)
|
||||
genUnencryptedCode client url expiration state = do
|
||||
now <- getCurrentTime
|
||||
let
|
||||
expiresAt = expiration `addUTCTime` now
|
||||
simpleCode = filter (/= ' ') $ client <> url <> show now <> show expiresAt
|
||||
simpleCode = replace "=" "" . encodeBase64 . pack . filter (/= ' ') $ client <> url <> show now <> show expiresAt
|
||||
success <- atomically . stateTVar state $ \s ->
|
||||
let mEntry = M.lookup simpleCode s.activeCodes
|
||||
in
|
||||
@ -50,13 +52,13 @@ genUnencryptedCode client url expiration state = do
|
||||
if success then expire simpleCode expiration state >> return (Just simpleCode) else return Nothing
|
||||
|
||||
|
||||
expire :: String -> NominalDiffTime -> AuthState -> IO ()
|
||||
expire :: Text -> NominalDiffTime -> AuthState -> IO ()
|
||||
expire code time state = void . forkIO $ do
|
||||
threadDelay $ fromEnum time
|
||||
atomically . modifyTVar state $ \s -> s{ activeCodes = M.delete code s.activeCodes }
|
||||
|
||||
|
||||
verify :: String -> String -> AuthState -> IO Bool
|
||||
verify :: Text -> String -> AuthState -> IO Bool
|
||||
verify code clientID state = do
|
||||
now <- getCurrentTime
|
||||
mData <- atomically $ do
|
||||
|
||||
@ -130,14 +130,14 @@ authServer = handleAuth
|
||||
redirect :: Maybe ByteString -> AuthHandler userData
|
||||
redirect (Just url) = liftIO (print url) >> throwError err303 { errHeaders = [("Location", url)]}
|
||||
redirect Nothing = throwError err500 { errBody = "Could not generate authorisation code."}
|
||||
addParams :: String -> Maybe String -> Maybe String -> Maybe ByteString
|
||||
addParams :: String -> Maybe Text -> Maybe String -> Maybe ByteString
|
||||
addParams url Nothing _ = Nothing
|
||||
addParams url (Just code) mState =
|
||||
let qPos = fromMaybe (length url) $ elemIndex '?' url
|
||||
(pre, post) = splitAt qPos url
|
||||
rState = case mState of {Just s -> "&state=" ++ s; Nothing -> ""}
|
||||
post' = if not (null post) then '&' : tail post else post
|
||||
in Just . fromString $ pre ++ "?code=" ++ code ++ post' ++ rState
|
||||
in Just . fromString $ pre ++ "?code=" ++ (unpack code) ++ post' ++ rState
|
||||
|
||||
|
||||
----------------------
|
||||
@ -217,7 +217,7 @@ tokenEndpoint = provideToken
|
||||
AuthCode -> do
|
||||
unless (Client (pack $ clientID client) (pack $ clientSecret client) `elem` trustedClients) .
|
||||
throwError $ err500 { errBody = "Invalid client" }
|
||||
valid <- asks (verify (grant client) (clientID client)) >>= liftIO
|
||||
valid <- asks (verify (pack $ grant client) (clientID client)) >>= liftIO
|
||||
unless valid . throwError $ err500 { errBody = "Invalid authorisation code" }
|
||||
-- return JWT {token = "", tokenType = "JWT", expiration = 0.25 * nominalDay}
|
||||
token <- asks mkToken >>= liftIO
|
||||
|
||||
Loading…
Reference in New Issue
Block a user