oauth2-mock-server/src/AuthCode.hs
2024-01-09 02:23:40 +01:00

63 lines
1.8 KiB
Haskell

{-# LANGUAGE OverloadedRecordDot #-}
module AuthCode
( State (..)
, AuthState
, genUnencryptedCode
, verify
) where
import Data.Map.Strict (Map)
import Data.Maybe (isJust)
import Data.Time.Clock
import qualified Data.Map.Strict as M
import Control.Concurrent (forkIO, threadDelay)
import Control.Concurrent.STM.TVar
import Control.Monad (void, (>=>))
import Control.Monad.STM
-- | Bookkeeping for issued authorization codes. Entries are added by
-- 'genUnencryptedCode' and removed by 'verify' or by the scheduled expiry.
newtype State = State
  { activeCodes :: Map String (String, UTCTime)
    -- ^ maps each auth code to its (client ID, expiration time)
  }
  deriving (Show)

-- | Mutable, STM-managed handle to the shared 'State'.
type AuthState = TVar State
-- | Issue a new, plain-text (unencrypted) authorization code and register it
-- in the shared state.
--
-- The code is the concatenation of the client ID, the URL, the issue time and
-- the expiry time, with every space character removed. If that exact code is
-- already active, nothing is stored and 'Nothing' is returned. Otherwise the
-- code is stored together with the client ID and expiry time, its removal is
-- scheduled via 'expire', and it is returned in 'Just'.
genUnencryptedCode
  :: String            -- ^ client ID the code is issued to
  -> String            -- ^ URL mixed into the code (presumably the redirect URI — confirm with caller)
  -> NominalDiffTime   -- ^ lifetime of the code
  -> AuthState         -- ^ shared store of active codes
  -> IO (Maybe String) -- ^ the new code, or 'Nothing' on collision
genUnencryptedCode client url expiration state = do
  issuedAt <- getCurrentTime
  let deadline = addUTCTime expiration issuedAt
      code = filter (/= ' ') (concat [client, url, show issuedAt, show deadline])
      -- Atomically insert the code unless it is already present;
      -- the Bool reports whether the insertion happened.
      register s
        | isJust (M.lookup code (activeCodes s)) = (False, s)
        | otherwise =
            (True, s { activeCodes = M.insert code (client, deadline) (activeCodes s) })
  inserted <- atomically (stateTVar state register)
  if inserted
    then do
      expire code expiration state
      return (Just code)
    else return Nothing
-- | Schedule removal of an auth code from the shared state once @time@ has
-- elapsed. This returns immediately, and the removal runs on a forked thread.
-- Removing a code that is already gone (e.g. consumed by 'verify') is a no-op.
--
-- Note: 'threadDelay' expects microseconds. @fromEnum@ on 'NominalDiffTime'
-- yields the underlying 'Data.Fixed.Pico' count (picoseconds), so it must not
-- be used for this conversion.
expire
  :: String          -- ^ auth code to remove
  -> NominalDiffTime -- ^ delay before removal
  -> AuthState       -- ^ shared store of active codes
  -> IO ()
expire code time state = void . forkIO $ do
  threadDelay delayMicros
  -- Strict modify so no chain of lazy updates accumulates inside the TVar.
  atomically . modifyTVar' state $ \s -> s { activeCodes = M.delete code s.activeCodes }
  where
    -- Seconds -> whole microseconds, clamped to [0, maxBound :: Int] so a
    -- very long lifetime cannot overflow 'Int'.
    delayMicros :: Int
    delayMicros =
      fromInteger
        . max 0
        . min (toInteger (maxBound :: Int))
        $ round (time * 1000000)
-- | Check and consume an authorization code.
--
-- The code is removed from the shared state whatever the outcome, so a code
-- can be presented at most once. The result is 'True' only if the code was
-- active, was issued to @clientID@, and its stored expiration time is still
-- in the future.
verify
  :: String    -- ^ authorization code presented by the client
  -> String    -- ^ client ID the code is expected to belong to
  -> AuthState -- ^ shared store of active codes
  -> IO Bool
verify code clientID state = do
  now <- getCurrentTime
  mData <- atomically $ do
    result <- (readTVar >=> return . M.lookup code . activeCodes) state
    -- One-time use: drop the code even when the client ID does not match.
    modifyTVar' state $ \s -> s { activeCodes = M.delete code s.activeCodes }
    return result
  return $ case mData of
    -- Check the stored deadline too; the background cleanup in 'expire'
    -- is asynchronous and may not have run yet.
    Just (clientID', expiresAt) -> clientID == clientID' && now < expiresAt
    Nothing -> False