345 lines
12 KiB
Haskell
345 lines
12 KiB
Haskell
-- | Internal module exposing the guts of the package. Use at
|
|
-- your own risk. No API stability guarantees apply.
|
|
module Web.ServerSession.Backend.Redis.Internal
|
|
( RedisStorage(..)
|
|
, RedisStorageException(..)
|
|
|
|
, transaction
|
|
, unwrap
|
|
, rSessionKey
|
|
, rAuthKey
|
|
|
|
, RedisSession(..)
|
|
, parseSession
|
|
, printSession
|
|
, parseUTCTime
|
|
, printUTCTime
|
|
, timeFormat
|
|
|
|
, getSessionImpl
|
|
, deleteSessionImpl
|
|
, removeSessionFromAuthId
|
|
, insertSessionForAuthId
|
|
, deleteAllSessionsOfAuthIdImpl
|
|
, insertSessionImpl
|
|
, replaceSessionImpl
|
|
, throwRS
|
|
) where
|
|
|
|
import Control.Applicative as A
|
|
import Control.Arrow (first)
|
|
import Control.Monad (void, when)
|
|
import Control.Monad.IO.Class (liftIO)
|
|
import Data.ByteString (ByteString)
|
|
import Data.List (partition)
|
|
import Data.Maybe (fromMaybe, catMaybes)
|
|
import Data.Proxy (Proxy(..))
|
|
import Data.Typeable (Typeable)
|
|
import Web.PathPieces (toPathPiece)
|
|
import Web.ServerSession.Core
|
|
|
|
import qualified Control.Exception as E
|
|
import qualified Database.Redis as R
|
|
import qualified Data.ByteString as B
|
|
import qualified Data.ByteString.Char8 as B8
|
|
import qualified Data.HashMap.Strict as HM
|
|
import qualified Data.Text.Encoding as TE
|
|
import qualified Data.Time.Clock as TI
|
|
import qualified Data.Time.Clock.POSIX as TP
|
|
import qualified Data.Time.Format as TI
|
|
|
|
#if MIN_VERSION_time(1,5,0)
|
|
import Data.Time.Format (defaultTimeLocale)
|
|
#else
|
|
import System.Locale (defaultTimeLocale)
|
|
#endif
|
|
|
|
----------------------------------------------------------------------
|
|
|
|
|
|
-- | Session storage backend using Redis via the @hedis@ package.
|
|
data RedisStorage sess =
|
|
RedisStorage
|
|
{ connPool :: R.Connection
|
|
-- ^ Connection pool to the Redis server.
|
|
, idleTimeout :: Maybe TI.NominalDiffTime
|
|
-- ^ How long should a session live after last access
|
|
, absoluteTimeout :: Maybe TI.NominalDiffTime
|
|
-- ^ How long should a session live after creation
|
|
} deriving (Typeable)
|
|
|
|
|
|
-- | We do not provide any ACID guarantees for different actions
|
|
-- running inside the same @TransactionM RedisStorage@.
|
|
instance RedisSession sess => Storage (RedisStorage sess) where
|
|
type SessionData (RedisStorage sess) = sess
|
|
type TransactionM (RedisStorage sess) = R.Redis
|
|
runTransactionM = R.runRedis . connPool
|
|
getSession _ = getSessionImpl
|
|
deleteSession _ = deleteSessionImpl
|
|
deleteAllSessionsOfAuthId _ = deleteAllSessionsOfAuthIdImpl
|
|
insertSession = insertSessionImpl
|
|
replaceSession = replaceSessionImpl
|
|
|
|
|
|
-- | An exception thrown by the @serversession-backend-redis@
|
|
-- package.
|
|
data RedisStorageException =
|
|
ExpectedTxSuccess (R.TxResult ())
|
|
-- ^ We expected 'TxSuccess' but got something else.
|
|
| ExpectedRight R.Reply
|
|
-- ^ We expected 'Right' from an @Either 'R.Reply' a@ but got
|
|
-- 'Left'.
|
|
deriving (Show, Typeable)
|
|
|
|
instance E.Exception RedisStorageException
|
|
|
|
|
|
----------------------------------------------------------------------
|
|
|
|
|
|
-- | Run the given Redis transaction and force its result.
|
|
-- Throws a 'RedisStorageException' if the result is not
|
|
-- 'TxSuccess'.
|
|
transaction :: R.RedisTx (R.Queued ()) -> R.Redis ()
|
|
transaction tx = do
|
|
ret <- R.multiExec tx
|
|
case ret of
|
|
R.TxSuccess () -> return ()
|
|
_ -> liftIO $ E.throwIO $ ExpectedTxSuccess ret
|
|
|
|
|
|
-- | Unwraps an @Either 'R.Reply' a@ by throwing an exception if
|
|
-- not @Right@.
|
|
unwrap :: R.Redis (Either R.Reply a) -> R.Redis a
|
|
unwrap act = act >>= either (liftIO . E.throwIO . ExpectedRight) return
|
|
|
|
|
|
-- | Redis key for the given session ID.
|
|
rSessionKey :: SessionId sess -> ByteString
|
|
rSessionKey = B.append "ssr:session:" . TE.encodeUtf8 . toPathPiece
|
|
|
|
|
|
-- | Redis key for the given auth ID.
|
|
rAuthKey :: AuthId -> ByteString
|
|
rAuthKey = B.append "ssr:authid:"
|
|
|
|
|
|
----------------------------------------------------------------------
|
|
|
|
|
|
-- | Class for data types that can be used as session data for
|
|
-- the Redis backend.
|
|
--
|
|
-- It should hold that
|
|
--
|
|
-- @
|
|
-- fromHash p . perm . toHash p === id
|
|
-- @
|
|
--
|
|
-- for all list permutations @perm :: [a] -> [a]@,
|
|
-- where @p :: Proxy sess@.
|
|
class IsSessionData sess => RedisSession sess where
|
|
-- | Transform a decomposed session into a Redis hash. Keys
|
|
-- will be prepended with @\"data:\"@ before being stored.
|
|
toHash :: Proxy sess -> Decomposed sess -> [(ByteString, ByteString)]
|
|
|
|
-- | Parse back a Redis hash into session data.
|
|
fromHash :: Proxy sess -> [(ByteString, ByteString)] -> Decomposed sess
|
|
|
|
|
|
-- | Assumes that keys are UTF-8 encoded when parsing (which is
|
|
-- true if keys are always generated via @toHash@).
|
|
instance RedisSession SessionMap where
|
|
toHash _ = map (first TE.encodeUtf8) . HM.toList . unSessionMap
|
|
fromHash _ = SessionMap . HM.fromList . map (first TE.decodeUtf8)
|
|
|
|
|
|
-- | Parse a 'Session' from a Redis hash.
|
|
parseSession
|
|
:: forall sess. RedisSession sess
|
|
=> SessionId sess
|
|
-> [(ByteString, ByteString)]
|
|
-> Maybe (Session sess)
|
|
parseSession _ [] = Nothing
|
|
parseSession sid bss =
|
|
let (externalList, internalList) = partition (B8.isPrefixOf "data:" . fst) bss
|
|
authId = lookup "internal:authId" internalList
|
|
createdAt = parseUTCTime $ lookup' "internal:createdAt"
|
|
accessedAt = parseUTCTime $ lookup' "internal:accessedAt"
|
|
lookup' k = fromMaybe (error err) $ lookup k internalList
|
|
where err = "serversession-backend-redis/parseSession: missing key " ++ show k
|
|
data_ = fromHash p $ map (first removePrefix) externalList
|
|
where removePrefix bs = let ("data:", key) = B8.splitAt 5 bs in key
|
|
p = Proxy :: Proxy sess
|
|
in Just Session
|
|
{ sessionKey = sid
|
|
, sessionAuthId = authId
|
|
, sessionData = data_
|
|
, sessionCreatedAt = createdAt
|
|
, sessionAccessedAt = accessedAt
|
|
}
|
|
|
|
|
|
-- | Convert a 'Session' into a Redis hash.
|
|
printSession :: forall sess. RedisSession sess => Session sess -> [(ByteString, ByteString)]
|
|
printSession Session {..} =
|
|
maybe id ((:) . (,) "internal:authId") sessionAuthId $
|
|
(:) ("internal:createdAt", printUTCTime sessionCreatedAt) $
|
|
(:) ("internal:accessedAt", printUTCTime sessionAccessedAt) $
|
|
map (first $ B8.append "data:") $
|
|
toHash (Proxy :: Proxy sess) sessionData
|
|
|
|
|
|
-- | Parse 'UTCTime' from a 'ByteString' stored on Redis. Uses
|
|
-- 'error' on parse error.
|
|
parseUTCTime :: ByteString -> TI.UTCTime
|
|
#if MIN_VERSION_time(1,5,0)
|
|
parseUTCTime = TI.parseTimeOrError True defaultTimeLocale timeFormat . B8.unpack
|
|
#else
|
|
parseUTCTime =
|
|
fromMaybe (error "Web.ServerSession.Backend.Redis.Internal.parseUTCTime") .
|
|
TI.parseTime defaultTimeLocale timeFormat . B8.unpack
|
|
#endif
|
|
|
|
|
|
-- | Convert a 'UTCTime' into a 'ByteString' to be stored on
|
|
-- Redis.
|
|
printUTCTime :: TI.UTCTime -> ByteString
|
|
printUTCTime = B8.pack . TI.formatTime defaultTimeLocale timeFormat
|
|
|
|
|
|
-- | Time format used when storing 'UTCTime'.
|
|
timeFormat :: String
|
|
timeFormat = "%Y-%m-%dT%H:%M:%S%Q"
|
|
|
|
|
|
----------------------------------------------------------------------
|
|
|
|
|
|
-- | Run the given Redis command in batches of @511*1024@ items.
|
|
-- This is used for @HMSET@ because there's a hard Redis limit of
|
|
-- @1024*1024@ arguments to a command. The last result is returned.
|
|
batched :: Monad m => ([a] -> m b) -> [a] -> m b
|
|
batched f xs =
|
|
let (this, rest) = splitAt (511*1024) xs
|
|
continue | null rest = return
|
|
| otherwise = const (batched f rest)
|
|
in f this >>= continue
|
|
|
|
|
|
-- | Get the session for the given session ID.
|
|
getSessionImpl :: RedisSession sess => SessionId sess -> R.Redis (Maybe (Session sess))
|
|
getSessionImpl sid = parseSession sid A.<$> unwrap (R.hgetall $ rSessionKey sid)
|
|
|
|
|
|
-- | Delete the session with given session ID.
|
|
deleteSessionImpl :: RedisSession sess => SessionId sess -> R.Redis ()
|
|
deleteSessionImpl sid = do
|
|
msession <- getSessionImpl sid
|
|
case msession of
|
|
Nothing -> return ()
|
|
Just session ->
|
|
transaction $ do
|
|
r <- R.del [rSessionKey sid]
|
|
removeSessionFromAuthId sid (sessionAuthId session)
|
|
return (() <$ r)
|
|
|
|
|
|
-- | Remove the given 'SessionId' from the set of sessions of the
|
|
-- given 'AuthId'. Does not do anything if @Nothing@.
|
|
removeSessionFromAuthId :: (R.RedisCtx m f, Functor m) => SessionId sess -> Maybe AuthId -> m ()
|
|
removeSessionFromAuthId = fooSessionBarAuthId R.srem
|
|
|
|
-- | Insert the given 'SessionId' into the set of sessions of the
|
|
-- given 'AuthId'. Does not do anything if @Nothing@.
|
|
insertSessionForAuthId :: (R.RedisCtx m f, Functor m) => SessionId sess -> Maybe AuthId -> m ()
|
|
insertSessionForAuthId = fooSessionBarAuthId R.sadd
|
|
|
|
|
|
-- | (Internal) Helper for 'removeSessionFromAuthId' and 'insertSessionForAuthId'
|
|
fooSessionBarAuthId
|
|
:: (R.RedisCtx m f, Functor m)
|
|
=> (ByteString -> [ByteString] -> m (f Integer))
|
|
-> SessionId sess
|
|
-> Maybe AuthId
|
|
-> m ()
|
|
fooSessionBarAuthId _ _ Nothing = return ()
|
|
fooSessionBarAuthId fun sid (Just authId) = void $ fun (rAuthKey authId) [rSessionKey sid]
|
|
|
|
|
|
-- | Delete all sessions of the given auth ID.
|
|
deleteAllSessionsOfAuthIdImpl :: AuthId -> R.Redis ()
|
|
deleteAllSessionsOfAuthIdImpl authId = do
|
|
sessionRefs <- unwrap $ R.smembers (rAuthKey authId)
|
|
void $ unwrap $ R.del $ rAuthKey authId : sessionRefs
|
|
|
|
|
|
-- | Insert a new session.
|
|
insertSessionImpl :: RedisSession sess => RedisStorage sess -> Session sess -> R.Redis ()
|
|
insertSessionImpl sto session = do
|
|
-- Check that no old session exists.
|
|
let sid = sessionKey session
|
|
moldSession <- getSessionImpl sid
|
|
case moldSession of
|
|
Just oldSession -> throwRS $ SessionAlreadyExists oldSession session
|
|
Nothing -> do
|
|
transaction $ do
|
|
let sk = rSessionKey sid
|
|
r <- batched (R.hmset sk) (printSession session)
|
|
expireSession session sto
|
|
insertSessionForAuthId (sessionKey session) (sessionAuthId session)
|
|
return (() <$ r)
|
|
|
|
|
|
-- | Replace the contents of a session.
|
|
replaceSessionImpl :: RedisSession sess => RedisStorage sess -> Session sess -> R.Redis ()
|
|
replaceSessionImpl sto session = do
|
|
-- Check that the old session exists.
|
|
let sid = sessionKey session
|
|
moldSession <- getSessionImpl sid
|
|
case moldSession of
|
|
Nothing -> throwRS $ SessionDoesNotExist session
|
|
Just oldSession -> do
|
|
transaction $ do
|
|
-- Delete the old session and set the new one.
|
|
let sk = rSessionKey sid
|
|
_ <- R.del [sk]
|
|
r <- batched (R.hmset sk) (printSession session)
|
|
expireSession session sto
|
|
|
|
-- Remove the old auth ID from the map if it has changed.
|
|
let oldAuthId = sessionAuthId oldSession
|
|
newAuthId = sessionAuthId session
|
|
when (oldAuthId /= newAuthId) $ do
|
|
removeSessionFromAuthId sid oldAuthId
|
|
insertSessionForAuthId sid newAuthId
|
|
|
|
return (() <$ r)
|
|
|
|
|
|
-- | Specialization of 'E.throwIO' for 'RedisStorage'.
|
|
throwRS
|
|
:: Storage (RedisStorage sess)
|
|
=> StorageException (RedisStorage sess)
|
|
-> R.Redis a
|
|
throwRS = liftIO . E.throwIO
|
|
|
|
|
|
-- | Given a session, finds the next time the session will time out,
|
|
-- either by idle or absolute timeout and schedule the key in redis to
|
|
-- expire at that time. This is meant to be used on every write to a
|
|
-- session so that it is constantly setting the appropriate timeout.
|
|
expireSession :: Session sess -> RedisStorage sess -> R.RedisTx ()
|
|
expireSession Session {..} RedisStorage {..} =
|
|
case minimum' (catMaybes [viaIdle, viaAbsolute]) of
|
|
Nothing -> return ()
|
|
Just t -> let ts = round (TP.utcTimeToPOSIXSeconds t)
|
|
in void (R.expireat sk ts)
|
|
where
|
|
sk = rSessionKey sessionKey
|
|
minimum' [] = Nothing
|
|
minimum' xs = Just (minimum xs)
|
|
viaIdle = flip TI.addUTCTime sessionAccessedAt <$> idleTimeout
|
|
viaAbsolute = flip TI.addUTCTime sessionCreatedAt <$> absoluteTimeout
|