diff --git a/serversession-backend-acid-state/src/Web/ServerSession/Backend/Acid/Internal.hs b/serversession-backend-acid-state/src/Web/ServerSession/Backend/Acid/Internal.hs index b9de44c..12ae658 100644 --- a/serversession-backend-acid-state/src/Web/ServerSession/Backend/Acid/Internal.hs +++ b/serversession-backend-acid-state/src/Web/ServerSession/Backend/Acid/Internal.hs @@ -6,6 +6,7 @@ module Web.ServerSession.Backend.Acid.Internal , ServerSessionAcidState(..) , emptyState , removeSessionFromAuthId + , insertSessionForAuthId , nothingfy , getSession @@ -23,13 +24,13 @@ module Web.ServerSession.Backend.Acid.Internal , AcidStorage(..) ) where -import Control.Monad (when) import Control.Monad.Reader (ask) -import Control.Monad.State (get, modify') +import Control.Monad.State (get, modify', put) import Data.Acid (AcidState, Query, Update, makeAcidic, query, update) import Data.SafeCopy (deriveSafeCopy, base) import Data.Typeable (Typeable) +import qualified Control.Exception as E import qualified Data.Map.Strict as M import qualified Data.Set as S import qualified Web.ServerSession.Core as SS @@ -77,6 +78,13 @@ removeSessionFromAuthId :: SS.SessionId -> Maybe SS.AuthId -> AuthIdToSessionId removeSessionFromAuthId sid = maybe id (M.update (nothingfy . S.delete sid)) +-- | Insert the given session ID as being part of the given auth +-- ID. Conceptually the opposite of 'removeSessionFromAuthId'. +-- Does not do anything if no 'AuthId' is provided. +insertSessionForAuthId :: SS.SessionId -> Maybe SS.AuthId -> AuthIdToSessionId -> AuthIdToSessionId +insertSessionForAuthId sid = maybe id (flip (M.insertWith S.union) (S.singleton sid)) + + -- | Change a 'S.Set' to 'Nothing' if it's 'S.null'. nothingfy :: S.Set a -> Maybe (S.Set a) nothingfy s = if S.null s then Nothing else Just s @@ -115,8 +123,10 @@ deleteAllSessionsOfAuthId authId = do -- | Insert a new session. insertSession :: SS.Session -> Update ServerSessionAcidState () insertSession session = do - let insertSess = M.insert sid session - insertAuth = maybe id (flip (M.insertWith S.union) (S.singleton sid)) (SS.sessionAuthId session) + let insertSess s = + let (mold, new) = M.insertLookupWithKey (\_ v _ -> v) sid session s + in maybe new (\old -> E.throw $ SS.SessionAlreadyExists old session) mold + insertAuth = insertSessionForAuthId sid (SS.sessionAuthId session) sid = SS.sessionKey session modify' $ \state -> ServerSessionAcidState @@ -127,16 +137,22 @@ insertSession session = do -- | Replace the contents of a session. replaceSession :: SS.Session -> Update ServerSessionAcidState () replaceSession session = do - -- Remove the old auth ID from the map if it has changed. - let sid = SS.sessionKey session - oldSession <- M.lookup sid . sessionIdToSession <$> get - let oldAuthId = SS.sessionAuthId =<< oldSession - when (oldAuthId /= SS.sessionAuthId session) $ - modify' $ \state -> state - { authIdToSessionId = removeSessionFromAuthId sid oldAuthId $ authIdToSessionId state - } - -- Otherwise the operation is the same as inserting. - insertSession session + -- Check that the old session exists while replacing it. + ServerSessionAcidState sits aits <- get + let (moldSession, sits') = M.updateLookupWithKey (\_ _ -> Just session) sid sits + sid = SS.sessionKey session + case moldSession of + Nothing -> E.throw $ SS.SessionDoesNotExist session + Just oldSession -> do + -- Remove/insert the old auth ID from the map if needed. + let modAits | oldAuthId == newAuthId = id + | otherwise = insertSessionForAuthId sid newAuthId + . removeSessionFromAuthId sid oldAuthId + where oldAuthId = SS.sessionAuthId oldSession + newAuthId = SS.sessionAuthId session + aits' = modAits aits + -- Put modified state in place + put (ServerSessionAcidState sits' aits') ----------------------------------------------------------------------