From 3e33c58af0a514aee4e3f41a3d2efb0a4e53c3cd Mon Sep 17 00:00:00 2001 From: Felipe Lessa Date: Sun, 31 May 2015 11:06:52 -0300 Subject: [PATCH] Generalize session data (huge commit). --- .../serversession-backend-acid-state.cabal | 3 + .../ServerSession/Backend/Acid/Internal.hs | 197 ++++++++-- .../serversession-backend-persistent.cabal | 7 +- .../Web/ServerSession/Backend/Persistent.hs | 6 +- .../Backend/Persistent/Internal/Impl.hs | 280 +++++++++++-- .../Backend/Persistent/Internal/Types.hs | 59 ++- .../tests/Main.hs | 4 +- .../serversession-backend-redis.cabal | 3 + .../src/Web/ServerSession/Backend/Redis.hs | 1 + .../ServerSession/Backend/Redis/Internal.hs | 86 +++- .../serversession-frontend-snap.cabal | 3 + .../src/Web/ServerSession/Frontend/Snap.hs | 1 + .../ServerSession/Frontend/Snap/Internal.hs | 128 ++++-- .../serversession-frontend-wai.cabal | 2 + .../src/Web/ServerSession/Frontend/Wai.hs | 1 + .../ServerSession/Frontend/Wai/Internal.hs | 54 ++- .../serversession-frontend-yesod.cabal | 3 + .../src/Web/ServerSession/Frontend/Yesod.hs | 11 + .../ServerSession/Frontend/Yesod/Internal.hs | 63 ++- serversession/serversession.cabal | 5 + serversession/src/Web/ServerSession/Core.hs | 4 +- .../src/Web/ServerSession/Core/Internal.hs | 368 ++++++++++++------ .../Web/ServerSession/Core/StorageTests.hs | 22 +- serversession/tests/Main.hs | 160 +++++--- 24 files changed, 1125 insertions(+), 346 deletions(-) diff --git a/serversession-backend-acid-state/serversession-backend-acid-state.cabal b/serversession-backend-acid-state/serversession-backend-acid-state.cabal index 0fcba63..f32ffea 100644 --- a/serversession-backend-acid-state/serversession-backend-acid-state.cabal +++ b/serversession-backend-acid-state/serversession-backend-acid-state.cabal @@ -27,9 +27,12 @@ library Web.ServerSession.Backend.Acid Web.ServerSession.Backend.Acid.Internal extensions: + ConstraintKinds DeriveDataTypeable + FlexibleContexts TemplateHaskell TypeFamilies + UndecidableInstances ghc-options: -Wall diff --git a/serversession-backend-acid-state/src/Web/ServerSession/Backend/Acid/Internal.hs b/serversession-backend-acid-state/src/Web/ServerSession/Backend/Acid/Internal.hs index 81df8a1..356126d 100644 --- a/serversession-backend-acid-state/src/Web/ServerSession/Backend/Acid/Internal.hs +++ b/serversession-backend-acid-state/src/Web/ServerSession/Backend/Acid/Internal.hs @@ -26,14 +26,16 @@ module Web.ServerSession.Backend.Acid.Internal import Control.Monad.Reader (ask) import Control.Monad.State (get, modify', put) -import Data.Acid (AcidState, Query, Update, makeAcidic, query, update) -import Data.SafeCopy (deriveSafeCopy, base) +import Data.Acid +import Data.Acid.Advanced +import Data.SafeCopy import Data.Typeable (Typeable) import qualified Control.Exception as E import qualified Data.Map.Strict as M import qualified Data.Set as S import qualified Web.ServerSession.Core as SS +import qualified Web.ServerSession.Core.Internal as SSI ---------------------------------------------------------------------- @@ -41,13 +43,13 @@ import qualified Web.ServerSession.Core as SS -- | Map from session IDs to sessions. The most important map, -- allowing us efficient access to a session given its ID. -type SessionIdToSession = M.Map SS.SessionId SS.Session +type SessionIdToSession sess = M.Map (SS.SessionId sess) (SS.Session sess) -- | Map from auth IDs to session IDs. Allow us to invalidate -- all sessions of given user without having to iterate through -- the whole 'SessionIdToSession' map. -type AuthIdToSessionId = M.Map SS.AuthId (S.Set SS.SessionId) +type AuthIdToSessionId sess = M.Map SS.AuthId (S.Set (SS.SessionId sess)) -- | The current sessions. @@ -55,33 +57,37 @@ type AuthIdToSessionId = M.Map SS.AuthId (S.Set SS.SessionId) -- Besides the obvious map from session IDs to sessions, we also -- maintain a map of auth IDs to session IDs. This allow us to -- quickly invalidate all sessions of a given user. -data ServerSessionAcidState = +data ServerSessionAcidState sess = ServerSessionAcidState - { sessionIdToSession :: !SessionIdToSession - , authIdToSessionId :: !AuthIdToSessionId - } deriving (Show, Typeable) - -deriveSafeCopy 0 'base ''SS.SessionId -- dangerous! -deriveSafeCopy 0 'base ''SS.Session -- dangerous! -deriveSafeCopy 0 'base ''ServerSessionAcidState + { sessionIdToSession :: !(SessionIdToSession sess) + , authIdToSessionId :: !(AuthIdToSessionId sess) + } deriving (Typeable) -- | Empty 'ServerSessionAcidState' used to bootstrap the 'AcidState'. -emptyState :: ServerSessionAcidState +emptyState :: ServerSessionAcidState sess emptyState = ServerSessionAcidState M.empty M.empty -- | Remove the given 'SessionId' from the set of the given -- 'AuthId' on the map. Does not do anything if no 'AuthId' is -- provided. -removeSessionFromAuthId :: SS.SessionId -> Maybe SS.AuthId -> AuthIdToSessionId -> AuthIdToSessionId +removeSessionFromAuthId + :: SS.SessionId sess + -> Maybe SS.AuthId + -> AuthIdToSessionId sess + -> AuthIdToSessionId sess removeSessionFromAuthId sid = maybe id (M.update (nothingfy . S.delete sid)) -- | Insert the given session ID as being part of the given auth -- ID. Conceptually the opposite of 'removeSessionFromAuthId'. -- Does not do anything if no 'AuthId' is provided. -insertSessionForAuthId :: SS.SessionId -> Maybe SS.AuthId -> AuthIdToSessionId -> AuthIdToSessionId +insertSessionForAuthId + :: SS.SessionId sess + -> Maybe SS.AuthId + -> AuthIdToSessionId sess + -> AuthIdToSessionId sess insertSessionForAuthId sid = maybe id (flip (M.insertWith S.union) (S.singleton sid)) @@ -93,13 +99,61 @@ nothingfy s = if S.null s then Nothing else Just s ---------------------------------------------------------------------- +deriveSafeCopy 0 'base ''SS.SessionMap + + +-- | We can't @deriveSafeCopy 0 'base ''SS.SessionId@ as +-- otherwise we'd require an unneeded @SafeCopy sess@. +instance SafeCopy (SS.SessionId sess) where + putCopy = contain . safePut . SSI.unS + getCopy = contain $ SSI.S <$> safeGet + + +-- | We can't @deriveSafeCopy 0 'base ''SS.Session@ due to the +-- required context. +instance SafeCopy (SS.Decomposed sess) => SafeCopy (SS.Session sess) where + putCopy (SS.Session key authId data_ createdAt accessedAt) = contain $ do + put_t <- getSafePut + safePut key + safePut authId + safePut data_ + put_t createdAt + put_t accessedAt + getCopy = contain $ do + get_t <- getSafeGet + SS.Session + <$> safeGet + <*> safeGet + <*> safeGet + <*> get_t + <*> get_t + + +-- | We can't @deriveSafeCopy 0 'base ''ServerSessionAcidState@ due +-- to the required context. +instance SafeCopy (SS.Decomposed sess) => SafeCopy (ServerSessionAcidState sess) where + putCopy (ServerSessionAcidState sits aits) = contain $ do + safePut sits + safePut aits + getCopy = contain $ ServerSessionAcidState <$> safeGet <*> safeGet + + +---------------------------------------------------------------------- + + -- | Get the session for the given session ID. -getSession :: SS.SessionId -> Query ServerSessionAcidState (Maybe SS.Session) +getSession + :: SS.Storage (AcidStorage sess) + => SS.SessionId sess + -> Query (ServerSessionAcidState sess) (Maybe (SS.Session sess)) getSession sid = M.lookup sid . sessionIdToSession <$> ask -- | Delete the session with given session ID. -deleteSession :: SS.SessionId -> Update ServerSessionAcidState () +deleteSession + :: SS.Storage (AcidStorage sess) + => SS.SessionId sess + -> Update (ServerSessionAcidState sess) () deleteSession sid = do let removeSession = M.updateLookupWithKey (\_ _ -> Nothing) sid modify' $ \state -> @@ -110,7 +164,10 @@ deleteSession sid = do -- | Delete all sessions of the given auth ID. -deleteAllSessionsOfAuthId :: SS.AuthId -> Update ServerSessionAcidState () +deleteAllSessionsOfAuthId + :: SS.Storage (AcidStorage sess) + => SS.AuthId + -> Update (ServerSessionAcidState sess) () deleteAllSessionsOfAuthId authId = do let removeSession = maybe id (flip M.difference . M.fromSet (const ())) removeAuth = M.updateLookupWithKey (\_ _ -> Nothing) authId @@ -121,11 +178,14 @@ deleteAllSessionsOfAuthId authId = do -- | Insert a new session. -insertSession :: SS.Session -> Update ServerSessionAcidState () +insertSession + :: SS.Storage (AcidStorage sess) + => SS.Session sess + -> Update (ServerSessionAcidState sess) () insertSession session = do let insertSess s = let (mold, new) = M.insertLookupWithKey (\_ v _ -> v) sid session s - in maybe new (\old -> E.throw $ SS.SessionAlreadyExists old session) mold + in maybe new (\old -> throwAS $ SS.SessionAlreadyExists old session) mold insertAuth = insertSessionForAuthId sid (SS.sessionAuthId session) sid = SS.sessionKey session modify' $ \state -> @@ -135,14 +195,17 @@ insertSession session = do -- | Replace the contents of a session. -replaceSession :: SS.Session -> Update ServerSessionAcidState () +replaceSession + :: SS.Storage (AcidStorage sess) + => SS.Session sess + -> Update (ServerSessionAcidState sess) () replaceSession session = do -- Check that the old session exists while replacing it. ServerSessionAcidState sits aits <- get let (moldSession, sits') = M.insertLookupWithKey (\_ v _ -> v) sid session sits sid = SS.sessionKey session case moldSession of - Nothing -> E.throw $ SS.SessionDoesNotExist session + Nothing -> throwAS $ SS.SessionDoesNotExist session Just oldSession -> do -- Remove/insert the old auth ID from the map if needed. let modAits | oldAuthId == newAuthId = id @@ -155,27 +218,103 @@ replaceSession session = do put (ServerSessionAcidState sits' aits') +-- | Specialization of 'E.throw' for 'AcidStorage'. +throwAS + :: SS.Storage (AcidStorage sess) + => SS.StorageException (AcidStorage sess) + -> a +throwAS = E.throw + + ---------------------------------------------------------------------- -makeAcidic ''ServerSessionAcidState ['getSession, 'deleteSession, 'deleteAllSessionsOfAuthId, 'insertSession, 'replaceSession] - - -- | Session storage backend using @acid-state@. -newtype AcidStorage = +newtype AcidStorage sess = AcidStorage - { acidState :: AcidState ServerSessionAcidState + { acidState :: AcidState (ServerSessionAcidState sess) -- ^ Open 'AcidState' of server sessions. } deriving (Typeable) -- | We do not provide any ACID guarantees for different actions -- running inside the same @TransactionM AcidStorage@. -instance SS.Storage AcidStorage where - type TransactionM AcidStorage = IO +instance ( SS.IsSessionData sess + , SafeCopy sess + , SafeCopy (SS.Decomposed sess) + ) => SS.Storage (AcidStorage sess) where + type SessionData (AcidStorage sess) = sess + type TransactionM (AcidStorage sess) = IO runTransactionM = const id getSession (AcidStorage s) = query s . GetSession deleteSession (AcidStorage s) = update s . DeleteSession deleteAllSessionsOfAuthId (AcidStorage s) = update s . DeleteAllSessionsOfAuthId insertSession (AcidStorage s) = update s . InsertSession replaceSession (AcidStorage s) = update s . ReplaceSession + + +---------------------------------------------------------------------- + +-- makeAcidic can't handle type variables, so we have to do +-- everything by hand. :( + +data GetSession sess = GetSession (SS.SessionId sess) +data DeleteSession sess = DeleteSession (SS.SessionId sess) +data DeleteAllSessionsOfAuthId sess = DeleteAllSessionsOfAuthId SS.AuthId +data InsertSession sess = InsertSession (SS.Session sess) +data ReplaceSession sess = ReplaceSession (SS.Session sess) + +instance SafeCopy (GetSession sess) where + putCopy (GetSession v) = contain $ safePut v + getCopy = contain $ GetSession <$> safeGet + +instance SafeCopy (DeleteSession sess) where + putCopy (DeleteSession v) = contain $ safePut v + getCopy = contain $ DeleteSession <$> safeGet + +instance SafeCopy (DeleteAllSessionsOfAuthId sess) where + putCopy (DeleteAllSessionsOfAuthId v) = contain $ safePut v + getCopy = contain $ DeleteAllSessionsOfAuthId <$> safeGet + +instance SafeCopy (SS.Decomposed sess) => SafeCopy (InsertSession sess) where + putCopy (InsertSession v) = contain $ safePut v + getCopy = contain $ InsertSession <$> safeGet + +instance SafeCopy (SS.Decomposed sess) => SafeCopy (ReplaceSession sess) where + putCopy (ReplaceSession v) = contain $ safePut v + getCopy = contain $ ReplaceSession <$> safeGet + +type AcidContext sess = + ( SS.IsSessionData sess + , SafeCopy sess + , SafeCopy (SS.Decomposed sess) ) + +instance AcidContext sess => QueryEvent (GetSession sess) +instance AcidContext sess => UpdateEvent (DeleteSession sess) +instance AcidContext sess => UpdateEvent (DeleteAllSessionsOfAuthId sess) +instance AcidContext sess => UpdateEvent (InsertSession sess) +instance AcidContext sess => UpdateEvent (ReplaceSession sess) + +instance AcidContext sess => Method (GetSession sess) where + type MethodResult (GetSession sess) = Maybe (SS.Session sess) + type MethodState (GetSession sess) = ServerSessionAcidState sess +instance AcidContext sess => Method (DeleteSession sess) where + type MethodResult (DeleteSession sess) = () + type MethodState (DeleteSession sess) = ServerSessionAcidState sess +instance AcidContext sess => Method (DeleteAllSessionsOfAuthId sess) where + type MethodResult (DeleteAllSessionsOfAuthId sess) = () + type MethodState (DeleteAllSessionsOfAuthId sess) = ServerSessionAcidState sess +instance AcidContext sess => Method (InsertSession sess) where + type MethodResult (InsertSession sess) = () + type MethodState (InsertSession sess) = ServerSessionAcidState sess +instance AcidContext sess => Method (ReplaceSession sess) where + type MethodResult (ReplaceSession sess) = () + type MethodState (ReplaceSession sess) = ServerSessionAcidState sess + +instance AcidContext sess => IsAcidic (ServerSessionAcidState sess) where + acidEvents = + [ QueryEvent $ \(GetSession sid) -> getSession sid + , UpdateEvent $ \(DeleteSession sid) -> deleteSession sid + , UpdateEvent $ \(DeleteAllSessionsOfAuthId authId) -> deleteAllSessionsOfAuthId authId + , UpdateEvent $ \(InsertSession session) -> insertSession session + , UpdateEvent $ \(ReplaceSession session) -> replaceSession session ] diff --git a/serversession-backend-persistent/serversession-backend-persistent.cabal b/serversession-backend-persistent/serversession-backend-persistent.cabal index 598c089..fa026c9 100644 --- a/serversession-backend-persistent/serversession-backend-persistent.cabal +++ b/serversession-backend-persistent/serversession-backend-persistent.cabal @@ -24,7 +24,7 @@ library , containers , path-pieces , persistent == 2.1.* - , persistent-template == 2.1.* + , tagged >= 0.8 , text , time , transformers @@ -37,14 +37,19 @@ library extensions: DeriveDataTypeable EmptyDataDecls + FlexibleContexts FlexibleInstances GADTs GeneralizedNewtypeDeriving OverloadedStrings + PatternGuards QuasiQuotes RecordWildCards + ScopedTypeVariables + StandaloneDeriving TemplateHaskell TypeFamilies + UndecidableInstances ghc-options: -Wall diff --git a/serversession-backend-persistent/src/Web/ServerSession/Backend/Persistent.hs b/serversession-backend-persistent/src/Web/ServerSession/Backend/Persistent.hs index d62db20..c5acd68 100644 --- a/serversession-backend-persistent/src/Web/ServerSession/Backend/Persistent.hs +++ b/serversession-backend-persistent/src/Web/ServerSession/Backend/Persistent.hs @@ -22,9 +22,11 @@ -- share [mkPersist sqlSettings, mkSave \"entityDefs\"] -- -- -- On Application.hs +-- import Data.Proxy (Proxy(..)) -- tagged package, or base from GHC 7.10 onwards +-- import Web.ServerSession.Core (SessionMap) -- import Web.ServerSession.Backend.Persistent (serverSessionDefs) -- --- mkMigrate \"migrateAll\" (serverSessionDefs ++ entityDefs) +-- mkMigrate \"migrateAll\" (serverSessionDefs (Proxy :: Proxy SessionMap) ++ entityDefs) -- -- makeFoundation = -- ... @@ -32,6 +34,8 @@ -- ... -- @ -- +-- If you're not using @SessionMap@, just change @Proxy@ type above. +-- -- If you forget to setup the migration above, this session -- storage backend will fail at runtime as the required table -- will not exist. diff --git a/serversession-backend-persistent/src/Web/ServerSession/Backend/Persistent/Internal/Impl.hs b/serversession-backend-persistent/src/Web/ServerSession/Backend/Persistent/Internal/Impl.hs index f328923..f5bc56b 100644 --- a/serversession-backend-persistent/src/Web/ServerSession/Backend/Persistent/Internal/Impl.hs +++ b/serversession-backend-persistent/src/Web/ServerSession/Backend/Persistent/Internal/Impl.hs @@ -9,68 +9,284 @@ module Web.ServerSession.Backend.Persistent.Internal.Impl , toPersistentSession , fromPersistentSession , SqlStorage(..) + , throwSS ) where import Control.Monad (void) import Control.Monad.IO.Class (liftIO) +import Data.Proxy (Proxy(..)) import Data.Time (UTCTime) import Data.Typeable (Typeable) import Database.Persist (PersistEntity(..)) -import Database.Persist.TH (mkPersist, mkSave, persistLowerCase, share, sqlSettings) +import Web.PathPieces (PathPiece) import Web.ServerSession.Core import qualified Control.Exception as E +import qualified Data.Aeson as A +import qualified Data.Map as M +import qualified Data.Text as T import qualified Database.Persist as P import qualified Database.Persist.Sql as P import Web.ServerSession.Backend.Persistent.Internal.Types -share - [mkPersist sqlSettings, mkSave "serverSessionDefs"] - [persistLowerCase| - PersistentSession json - key SessionId -- Session ID, primary key. - authId ByteStringJ Maybe -- Value of "_ID" session key. - session SessionMapJ -- Rest of the session data. - createdAt UTCTime -- When this session was created. - accessedAt UTCTime -- When this session was last accessed. - Primary key - deriving Eq Ord Show Typeable - |] +-- We can't use the Template Haskell since we want to generalize +-- some fields. +-- +-- This is going to be a pain to upgrade when the next major +-- persistent version comes :(. + +-- | Entity corresponding to a 'Session'. +-- +-- We're bending @persistent@ in ways it wasn't expected to. In +-- particular, this entity is parametrized over the session type. +data PersistentSession sess = + PersistentSession + { persistentSessionKey :: !(SessionId sess) -- ^ Session ID, primary key. + , persistentSessionAuthId :: !(Maybe ByteStringJ) -- ^ Value of "_ID" session key. + , persistentSessionSession :: !(Decomposed sess) -- ^ Rest of the session data. + , persistentSessionCreatedAt :: !UTCTime -- ^ When this session was created. + , persistentSessionAccessedAt :: !UTCTime -- ^ When this session was last accessed. + } deriving (Typeable) + +deriving instance Eq (Decomposed sess) => Eq (PersistentSession sess) +deriving instance Ord (Decomposed sess) => Ord (PersistentSession sess) +deriving instance Show (Decomposed sess) => Show (PersistentSession sess) + + +type PersistentSessionId sess = Key (PersistentSession sess) + +instance forall sess. P.PersistFieldSql (Decomposed sess) => P.PersistEntity (PersistentSession sess) where + type PersistEntityBackend (PersistentSession sess) = P.SqlBackend + + data Unique (PersistentSession sess) + + newtype Key (PersistentSession sess) = + PersistentSessionKey' {unPersistentSessionKey :: SessionId sess} + deriving ( Eq, Ord, Show, Read, PathPiece + , P.PersistField, P.PersistFieldSql, A.ToJSON, A.FromJSON ) + + data EntityField (PersistentSession sess) typ = + typ ~ PersistentSessionId sess => PersistentSessionId + | typ ~ SessionId sess => PersistentSessionKey + | typ ~ Maybe ByteStringJ => PersistentSessionAuthId + | typ ~ Decomposed sess => PersistentSessionSession + | typ ~ UTCTime => PersistentSessionCreatedAt + | typ ~ UTCTime => PersistentSessionAccessedAt + + keyToValues = (:[]) . P.toPersistValue . unPersistentSessionKey + keyFromValues [x] | Right v <- P.fromPersistValue x = Right $ PersistentSessionKey' v + keyFromValues xs = Left $ T.pack $ "PersistentSession/keyFromValues: " ++ show xs + + entityDef _ + = P.EntityDef + (P.HaskellName "PersistentSession") + (P.DBName "persistent_session") + (pfd PersistentSessionId) + ["json"] + [ pfd PersistentSessionKey + , pfd PersistentSessionAuthId + , pfd PersistentSessionSession + , pfd PersistentSessionCreatedAt + , pfd PersistentSessionAccessedAt ] + [] + [] + ["Eq", "Ord", "Show", "Typeable"] + M.empty + False + where + pfd :: P.EntityField (PersistentSession sess) typ -> P.FieldDef + pfd = P.persistFieldDef + + toPersistFields (PersistentSession a b c d e) = + [ P.SomePersistField a + , P.SomePersistField b + , P.SomePersistField c + , P.SomePersistField d + , P.SomePersistField e ] + + fromPersistValues [a, b, c, d, e] = + PersistentSession + <$> err "key" (P.fromPersistValue a) + <*> err "authId" (P.fromPersistValue b) + <*> err "session" (P.fromPersistValue c) + <*> err "createdAt" (P.fromPersistValue d) + <*> err "accessedAt" (P.fromPersistValue e) + where + err :: T.Text -> Either T.Text a -> Either T.Text a + err s (Left r) = Left $ T.concat ["PersistentSession/fromPersistValues/", s, ": ", r] + err _ (Right v) = Right v + fromPersistValues x = Left $ T.pack $ "PersistentSession/fromPersistValues: " ++ show x + + persistUniqueToFieldNames _ = error "Degenerate case, should never happen" + persistUniqueToValues _ = error "Degenerate case, should never happen" + persistUniqueKeys _ = [] + + persistFieldDef PersistentSessionId + = P.FieldDef + (P.HaskellName "Id") + (P.DBName "id") + (P.FTTypeCon + Nothing "PersistentSessionId") + (P.SqlOther "Composite Reference") + [] + True + (P.CompositeRef + (P.CompositeDef + [P.FieldDef + (P.HaskellName "key") + (P.DBName "key") + (P.FTTypeCon Nothing "SessionId") + (P.SqlOther "SqlType unset for key") + [] + True + P.NoReference] + [])) + persistFieldDef PersistentSessionKey + = P.FieldDef + (P.HaskellName "key") + (P.DBName "key") + (P.FTTypeCon Nothing "SessionId sess") + (P.sqlType (Proxy :: Proxy (SessionId sess))) + [] + True + P.NoReference + persistFieldDef PersistentSessionAuthId + = P.FieldDef + (P.HaskellName "authId") + (P.DBName "auth_id") + (P.FTTypeCon Nothing "ByteStringJ") + (P.sqlType (Proxy :: Proxy ByteStringJ)) + ["Maybe"] + True + P.NoReference + persistFieldDef PersistentSessionSession + = P.FieldDef + (P.HaskellName "session") + (P.DBName "session") + (P.FTTypeCon Nothing "Decomposed sess") + (P.sqlType (Proxy :: Proxy (Decomposed sess))) -- Important! + [] + True + P.NoReference + persistFieldDef PersistentSessionCreatedAt + = P.FieldDef + (P.HaskellName "createdAt") + (P.DBName "created_at") + (P.FTTypeCon Nothing "UTCTime") + (P.sqlType (Proxy :: Proxy UTCTime)) + [] + True + P.NoReference + persistFieldDef PersistentSessionAccessedAt + = P.FieldDef + (P.HaskellName "accessedAt") + (P.DBName "accessed_at") + (P.FTTypeCon Nothing "UTCTime") + (P.sqlType (Proxy :: Proxy UTCTime)) + [] + True + P.NoReference + + persistIdField = PersistentSessionId + + fieldLens PersistentSessionId = lensPTH + P.entityKey + (\(P.Entity _ v) k -> P.Entity k v) + fieldLens PersistentSessionKey = lensPTH + (persistentSessionKey . P.entityVal) + (\(P.Entity k v) x -> P.Entity k (v {persistentSessionKey = x})) + fieldLens PersistentSessionAuthId = lensPTH + (persistentSessionAuthId . P.entityVal) + (\(P.Entity k v) x -> P.Entity k (v {persistentSessionAuthId = x})) + fieldLens PersistentSessionSession = lensPTH + (persistentSessionSession . P.entityVal) + (\(P.Entity k v) x -> P.Entity k (v {persistentSessionSession = x})) + fieldLens PersistentSessionCreatedAt = lensPTH + (persistentSessionCreatedAt . P.entityVal) + (\(P.Entity k v) x -> P.Entity k (v {persistentSessionCreatedAt = x})) + fieldLens PersistentSessionAccessedAt = lensPTH + (persistentSessionAccessedAt . P.entityVal) + (\(P.Entity k v) x -> P.Entity k (v {persistentSessionAccessedAt = x})) + + +-- | Copy-paste from @Database.Persist.TH@. Who needs lens anyway... +lensPTH :: Functor f => (s -> a) -> (s -> b -> t) -> (a -> f b) -> s -> f t +lensPTH sa sbt afb s = fmap (sbt s) (afb $ sa s) + + +instance A.ToJSON (Decomposed sess) => A.ToJSON (PersistentSession sess) where + toJSON (PersistentSession key authId session createdAt accessedAt) = + A.object + [ "key" A..= key + , "authId" A..= authId + , "session" A..= session + , "createdAt" A..= createdAt + , "accessedAt" A..= accessedAt ] + +instance A.FromJSON (Decomposed sess) => A.FromJSON (PersistentSession sess) where + parseJSON (A.Object obj) = + PersistentSession + <$> obj A..: "key" + <*> obj A..: "authId" + <*> obj A..: "session" + <*> obj A..: "createdAt" + <*> obj A..: "accessedAt" + parseJSON _ = mempty + +instance ( A.ToJSON (Decomposed sess) + , P.PersistFieldSql (Decomposed sess) + ) => A.ToJSON (P.Entity (PersistentSession sess)) where + toJSON = P.entityIdToJSON + +instance ( A.FromJSON (Decomposed sess) + , P.PersistFieldSql (Decomposed sess) + ) => A.FromJSON (P.Entity (PersistentSession sess)) where + parseJSON = P.entityIdFromJSON + + +-- | Entity definitions needed to generate the SQL schema for +-- 'SqlStorage'. Example using 'SessionMap': +-- +-- @ +-- serverSessionDefs (Proxy :: Proxy SessionMap) +-- @ +serverSessionDefs :: forall sess. PersistEntity (PersistentSession sess) => Proxy sess -> [P.EntityDef] +serverSessionDefs _ = [entityDef (Proxy :: Proxy (PersistentSession sess))] -- | Generate a key to the entity from the session ID. -psKey :: SessionId -> Key PersistentSession +psKey :: SessionId sess -> Key (PersistentSession sess) psKey = PersistentSessionKey' -- | Convert from 'Session' to 'PersistentSession'. -toPersistentSession :: Session -> PersistentSession +toPersistentSession :: Session sess -> PersistentSession sess toPersistentSession Session {..} = PersistentSession { persistentSessionKey = sessionKey , persistentSessionAuthId = fmap B sessionAuthId - , persistentSessionSession = M sessionData + , persistentSessionSession = sessionData , persistentSessionCreatedAt = sessionCreatedAt , persistentSessionAccessedAt = sessionAccessedAt } -- | Convert from 'PersistentSession' to 'Session'. -fromPersistentSession :: PersistentSession -> Session +fromPersistentSession :: PersistentSession sess -> Session sess fromPersistentSession PersistentSession {..} = Session { sessionKey = persistentSessionKey , sessionAuthId = fmap unB persistentSessionAuthId - , sessionData = unM persistentSessionSession + , sessionData = persistentSessionSession , sessionCreatedAt = persistentSessionCreatedAt , sessionAccessedAt = persistentSessionAccessedAt } -- | SQL session storage backend using @persistent@. -newtype SqlStorage = +newtype SqlStorage sess = SqlStorage { connPool :: P.ConnectionPool -- ^ Pool of DB connections. You may use the same pool as @@ -78,22 +294,38 @@ newtype SqlStorage = } deriving (Typeable) -instance Storage SqlStorage where - type TransactionM SqlStorage = P.SqlPersistT IO +instance forall sess. + ( IsSessionData sess + , P.PersistFieldSql (Decomposed sess) + ) => Storage (SqlStorage sess) where + type SessionData (SqlStorage sess) = sess + type TransactionM (SqlStorage sess) = P.SqlPersistT IO runTransactionM = flip P.runSqlPool . connPool getSession _ = fmap (fmap fromPersistentSession) . P.get . psKey deleteSession _ = P.delete . psKey - deleteAllSessionsOfAuthId _ authId = P.deleteWhere [PersistentSessionAuthId P.==. Just (B authId)] + deleteAllSessionsOfAuthId _ authId = + P.deleteWhere [field P.==. Just (B authId)] + where + field :: EntityField (PersistentSession sess) (Maybe ByteStringJ) + field = PersistentSessionAuthId insertSession s session = do mold <- getSession s (sessionKey session) maybe (void $ P.insert $ toPersistentSession session) - (\old -> liftIO $ E.throwIO $ SessionAlreadyExists old session) + (\old -> throwSS $ SessionAlreadyExists old session) mold replaceSession _ session = do let key = psKey $ sessionKey session mold <- P.get key maybe - (liftIO $ E.throwIO $ SessionDoesNotExist session) + (throwSS $ SessionDoesNotExist session) (\_old -> void $ P.replace key $ toPersistentSession session) mold + + +-- | Specialization of 'E.throwIO' for 'SqlStorage'. +throwSS + :: Storage (SqlStorage sess) + => StorageException (SqlStorage sess) + -> TransactionM (SqlStorage sess) a +throwSS = liftIO . E.throwIO diff --git a/serversession-backend-persistent/src/Web/ServerSession/Backend/Persistent/Internal/Types.hs b/serversession-backend-persistent/src/Web/ServerSession/Backend/Persistent/Internal/Types.hs index 2b4976f..72efb98 100644 --- a/serversession-backend-persistent/src/Web/ServerSession/Backend/Persistent/Internal/Types.hs +++ b/serversession-backend-persistent/src/Web/ServerSession/Backend/Persistent/Internal/Types.hs @@ -5,7 +5,11 @@ -- Also exports orphan instances of @PersistField{,Sql} SessionId@. module Web.ServerSession.Backend.Persistent.Internal.Types ( ByteStringJ(..) - , SessionMapJ(..) + -- * Orphan instances + -- ** SessionId + -- $orphanSessionId + -- ** SessionMap + -- $orphanSessionMap ) where import Control.Arrow (first) @@ -29,12 +33,19 @@ import qualified Data.Text.Encoding as TE ---------------------------------------------------------------------- --- | Does not do sanity checks (DB is trusted). -instance PersistField SessionId where +-- $orphanSessionId +-- +-- @ +-- instance 'PersistField' ('SessionId' sess) +-- instance 'PersistFieldSql' ('SessionId' sess) +-- @ +-- +-- Does not do sanity checks (DB is trusted). +instance PersistField (SessionId sess) where toPersistValue = toPersistValue . unS fromPersistValue = fmap S . fromPersistValue -instance PersistFieldSql SessionId where +instance PersistFieldSql (SessionId sess) where sqlType p = sqlType (fmap unS p) @@ -66,31 +77,41 @@ instance A.ToJSON ByteStringJ where ---------------------------------------------------------------------- --- | Newtype of a 'SessionMap' that serializes using @cereal@ on +-- $orphanSessionMap +-- +-- @ +-- instance 'PersistField' 'SessionMap' +-- instance 'PersistFieldSql' 'SessionMap' +-- instance 'S.Serialize' 'SessionMap' +-- instance 'A.FromJSON' 'SessionMap' +-- instance 'A.ToJSON' 'SessionMap' +-- @ + +-- 'PersistField' for 'SessionMap' serializes using @cereal@ on -- the database. We tried to use @aeson@ but @cereal@ is twice -- faster and uses half the memory for this use case. -newtype SessionMapJ = M { unM :: SessionMap } - deriving (Eq, Ord, Show, Read, Typeable) - -instance PersistField SessionMapJ where +-- +-- The JSON instance translates to objects using base64url for +-- the values of 'ByteString' (cf. 'ByteStringJ'). +instance PersistField SessionMap where toPersistValue = toPersistValue . S.encode fromPersistValue = fromPersistValue >=> (either (Left . T.pack) Right . S.decode) -instance PersistFieldSql SessionMapJ where +instance PersistFieldSql SessionMap where sqlType _ = SqlBlob -instance S.Serialize SessionMapJ where - put = S.put . map (first TE.encodeUtf8) . M.toAscList . unM - get = M . M.fromAscList . map (first TE.decodeUtf8) <$> S.get +instance S.Serialize SessionMap where + put = S.put . map (first TE.encodeUtf8) . M.toAscList . unSessionMap + get = SessionMap . M.fromAscList . map (first TE.decodeUtf8) <$> S.get -instance A.FromJSON SessionMapJ where +instance A.FromJSON SessionMap where parseJSON = fmap fixup . A.parseJSON where - fixup :: M.Map Text ByteStringJ -> SessionMapJ - fixup = M . fmap unB + fixup :: M.Map Text ByteStringJ -> SessionMap + fixup = SessionMap . fmap unB -instance A.ToJSON SessionMapJ where +instance A.ToJSON SessionMap where toJSON = A.toJSON . mangle where - mangle :: SessionMapJ -> M.Map Text ByteStringJ - mangle = fmap B . unM + mangle :: SessionMap -> M.Map Text ByteStringJ + mangle = fmap B . unSessionMap diff --git a/serversession-backend-persistent/tests/Main.hs b/serversession-backend-persistent/tests/Main.hs index 09bf574..0de3e36 100644 --- a/serversession-backend-persistent/tests/Main.hs +++ b/serversession-backend-persistent/tests/Main.hs @@ -3,17 +3,19 @@ module Main (main) where import Control.Monad (forM_) import Control.Monad.Logger (runStderrLoggingT, runNoLoggingT) import Data.Pool (destroyAllResources) +import Data.Proxy (Proxy(..)) import Database.Persist.Postgresql (createPostgresqlPool) import Database.Persist.Sqlite (createSqlitePool) import Test.Hspec import Web.ServerSession.Backend.Persistent +import Web.ServerSession.Core (SessionMap) import Web.ServerSession.Core.StorageTests import qualified Control.Exception as E import qualified Database.Persist.TH as P import qualified Database.Persist.Sql as P -P.mkMigrate "migrateAll" serverSessionDefs +P.mkMigrate "migrateAll" (serverSessionDefs (Proxy :: Proxy SessionMap)) main :: IO () main = hspec $ diff --git a/serversession-backend-redis/serversession-backend-redis.cabal b/serversession-backend-redis/serversession-backend-redis.cabal index d69bf6d..1c1dbf2 100644 --- a/serversession-backend-redis/serversession-backend-redis.cabal +++ b/serversession-backend-redis/serversession-backend-redis.cabal @@ -21,6 +21,7 @@ library , containers , hedis == 0.6.* , path-pieces + , tagged >= 0.8 , text , time >= 1.5 , transformers @@ -31,8 +32,10 @@ library Web.ServerSession.Backend.Redis.Internal extensions: DeriveDataTypeable + FlexibleContexts OverloadedStrings RecordWildCards + ScopedTypeVariables TypeFamilies ghc-options: -Wall diff --git a/serversession-backend-redis/src/Web/ServerSession/Backend/Redis.hs b/serversession-backend-redis/src/Web/ServerSession/Backend/Redis.hs index f1a2bcd..f64a4f1 100644 --- a/serversession-backend-redis/src/Web/ServerSession/Backend/Redis.hs +++ b/serversession-backend-redis/src/Web/ServerSession/Backend/Redis.hs @@ -2,6 +2,7 @@ module Web.ServerSession.Backend.Redis ( RedisStorage(..) , RedisStorageException(..) + , RedisSession(..) ) where import Web.ServerSession.Backend.Redis.Internal diff --git a/serversession-backend-redis/src/Web/ServerSession/Backend/Redis/Internal.hs b/serversession-backend-redis/src/Web/ServerSession/Backend/Redis/Internal.hs index c71be27..39dc3cf 100644 --- a/serversession-backend-redis/src/Web/ServerSession/Backend/Redis/Internal.hs +++ b/serversession-backend-redis/src/Web/ServerSession/Backend/Redis/Internal.hs @@ -9,6 +9,7 @@ module Web.ServerSession.Backend.Redis.Internal , rSessionKey , rAuthKey + , RedisSession(..) , parseSession , printSession , parseUTCTime @@ -22,6 +23,7 @@ module Web.ServerSession.Backend.Redis.Internal , deleteAllSessionsOfAuthIdImpl , insertSessionImpl , replaceSessionImpl + , throwRS ) where import Control.Applicative ((<$)) @@ -31,6 +33,7 @@ import Control.Monad.IO.Class (liftIO) import Data.ByteString (ByteString) import Data.List (partition) import Data.Maybe (fromMaybe) +import Data.Proxy (Proxy(..)) import Data.Typeable (Typeable) import Web.PathPieces (toPathPiece) import Web.ServerSession.Core @@ -49,7 +52,7 @@ import qualified Data.Time.Format as TI -- | Session storage backend using Redis via the @hedis@ package. -newtype RedisStorage = +newtype RedisStorage sess = RedisStorage { connPool :: R.Connection -- ^ Connection pool to the Redis server. @@ -58,8 +61,9 @@ newtype RedisStorage = -- | We do not provide any ACID guarantees for different actions -- running inside the same @TransactionM RedisStorage@. -instance Storage RedisStorage where - type TransactionM RedisStorage = R.Redis +instance RedisSession sess => Storage (RedisStorage sess) where + type SessionData (RedisStorage sess) = sess + type TransactionM (RedisStorage sess) = R.Redis runTransactionM = R.runRedis . connPool getSession _ = getSessionImpl deleteSession _ = deleteSessionImpl @@ -102,7 +106,7 @@ unwrap act = act >>= either (liftIO . E.throwIO . ExpectedRight) return -- | Redis key for the given session ID. -rSessionKey :: SessionId -> ByteString +rSessionKey :: SessionId sess -> ByteString rSessionKey = B.append "ssr:session:" . TE.encodeUtf8 . toPathPiece @@ -114,8 +118,39 @@ rAuthKey = B.append "ssr:authid:" ---------------------------------------------------------------------- +-- | Class for data types that can be used as session data for +-- the Redis backend. +-- +-- It should hold that +-- +-- @ +-- fromHash p . perm . toHash p === id +-- @ +-- +-- for all list permutations @perm :: [a] -> [a]@, +-- where @p :: Proxy sess@. +class IsSessionData sess => RedisSession sess where + -- | Transform a decomposed session into a Redis hash. Keys + -- will be prepended with @\"data:\"@ before being stored. + toHash :: Proxy sess -> Decomposed sess -> [(ByteString, ByteString)] + + -- | Parse back a Redis hash into session data. + fromHash :: Proxy sess -> [(ByteString, ByteString)] -> Decomposed sess + + +-- | Assumes that keys are UTF-8 encoded when parsing (which is +-- true if keys are always generated via @toHash@). +instance RedisSession SessionMap where + toHash _ = map (first TE.encodeUtf8) . M.toList . unSessionMap + fromHash _ = SessionMap . M.fromList . map (first TE.decodeUtf8) + + -- | Parse a 'Session' from a Redis hash. -parseSession :: SessionId -> [(ByteString, ByteString)] -> Maybe Session +parseSession + :: forall sess. RedisSession sess + => SessionId sess + -> [(ByteString, ByteString)] + -> Maybe (Session sess) parseSession _ [] = Nothing parseSession sid bss = let (externalList, internalList) = partition (B8.isPrefixOf "data:" . fst) bss @@ -124,25 +159,26 @@ parseSession sid bss = accessedAt = parseUTCTime $ lookup' "internal:accessedAt" lookup' k = fromMaybe (error err) $ lookup k internalList where err = "serversession-backend-redis/parseSession: missing key " ++ show k - sessionMap = M.fromList $ map (first $ TE.decodeUtf8 . removePrefix) externalList + data_ = fromHash p $ map (first removePrefix) externalList where removePrefix bs = let ("data:", key) = B8.splitAt 5 bs in key + p = Proxy :: Proxy sess in Just Session { sessionKey = sid , sessionAuthId = authId - , sessionData = sessionMap + , sessionData = data_ , sessionCreatedAt = createdAt , sessionAccessedAt = accessedAt } -- | Convert a 'Session' into a Redis hash. -printSession :: Session -> [(ByteString, ByteString)] +printSession :: forall sess. RedisSession sess => Session sess -> [(ByteString, ByteString)] printSession Session {..} = maybe id ((:) . (,) "internal:authId") sessionAuthId $ (:) ("internal:createdAt", printUTCTime sessionCreatedAt) $ (:) ("internal:accessedAt", printUTCTime sessionAccessedAt) $ - map (first $ B8.append "data:" . TE.encodeUtf8) $ - M.toList sessionData + map (first $ B8.append "data:") $ + toHash (Proxy :: Proxy sess) sessionData -- | Parse 'UTCTime' from a 'ByteString' stored on Redis. Uses @@ -177,12 +213,12 @@ batched f xs = -- | Get the session for the given session ID. -getSessionImpl :: SessionId -> R.Redis (Maybe Session) +getSessionImpl :: RedisSession sess => SessionId sess -> R.Redis (Maybe (Session sess)) getSessionImpl sid = parseSession sid <$> unwrap (R.hgetall $ rSessionKey sid) -- | Delete the session with given session ID. -deleteSessionImpl :: SessionId -> R.Redis () +deleteSessionImpl :: RedisSession sess => SessionId sess -> R.Redis () deleteSessionImpl sid = do msession <- getSessionImpl sid case msession of @@ -196,18 +232,22 @@ deleteSessionImpl sid = do -- | Remove the given 'SessionId' from the set of sessions of the -- given 'AuthId'. Does not do anything if @Nothing@. -removeSessionFromAuthId :: R.RedisCtx m f => SessionId -> Maybe AuthId -> m () +removeSessionFromAuthId :: R.RedisCtx m f => SessionId sess -> Maybe AuthId -> m () removeSessionFromAuthId = fooSessionBarAuthId R.srem -- | Insert the given 'SessionId' into the set of sessions of the -- given 'AuthId'. Does not do anything if @Nothing@. -insertSessionForAuthId :: R.RedisCtx m f => SessionId -> Maybe AuthId -> m () +insertSessionForAuthId :: R.RedisCtx m f => SessionId sess -> Maybe AuthId -> m () insertSessionForAuthId = fooSessionBarAuthId R.sadd -- | (Internal) Helper for 'removeSessionFromAuthId' and 'insertSessionForAuthId' fooSessionBarAuthId - :: R.RedisCtx m f => (ByteString -> [ByteString] -> m (f Integer)) -> SessionId -> Maybe AuthId -> m () + :: R.RedisCtx m f + => (ByteString -> [ByteString] -> m (f Integer)) + -> SessionId sess + -> Maybe AuthId + -> m () fooSessionBarAuthId _ _ Nothing = return () fooSessionBarAuthId fun sid (Just authId) = void $ fun (rAuthKey authId) [rSessionKey sid] @@ -220,13 +260,13 @@ deleteAllSessionsOfAuthIdImpl authId = do -- | Insert a new session. -insertSessionImpl :: Session -> R.Redis () +insertSessionImpl :: RedisSession sess => Session sess -> R.Redis () insertSessionImpl session = do -- Check that no old session exists. let sid = sessionKey session moldSession <- getSessionImpl sid case moldSession of - Just oldSession -> liftIO $ E.throwIO $ SessionAlreadyExists oldSession session + Just oldSession -> throwRS $ SessionAlreadyExists oldSession session Nothing -> do transaction $ do let sk = rSessionKey sid @@ -237,13 +277,13 @@ insertSessionImpl session = do -- | Replace the contents of a session. -replaceSessionImpl :: Session -> R.Redis () +replaceSessionImpl :: RedisSession sess => Session sess -> R.Redis () replaceSessionImpl session = do -- Check that the old session exists. let sid = sessionKey session moldSession <- getSessionImpl sid case moldSession of - Nothing -> liftIO $ E.throwIO $ SessionDoesNotExist session + Nothing -> throwRS $ SessionDoesNotExist session Just oldSession -> do transaction $ do -- Delete the old session and set the new one. @@ -259,3 +299,11 @@ replaceSessionImpl session = do insertSessionForAuthId sid newAuthId return (() <$ r) + + +-- | Specialization of 'E.throwIO' for 'RedisStorage'. +throwRS + :: Storage (RedisStorage sess) + => StorageException (RedisStorage sess) + -> R.Redis a +throwRS = liftIO . E.throwIO diff --git a/serversession-frontend-snap/serversession-frontend-snap.cabal b/serversession-frontend-snap/serversession-frontend-snap.cabal index 22431ce..b01ac6e 100644 --- a/serversession-frontend-snap/serversession-frontend-snap.cabal +++ b/serversession-frontend-snap/serversession-frontend-snap.cabal @@ -32,7 +32,10 @@ library Web.ServerSession.Frontend.Snap Web.ServerSession.Frontend.Snap.Internal extensions: + FlexibleContexts OverloadedStrings + TypeFamilies + UndecidableInstances ghc-options: -Wall source-repository head diff --git a/serversession-frontend-snap/src/Web/ServerSession/Frontend/Snap.hs b/serversession-frontend-snap/src/Web/ServerSession/Frontend/Snap.hs index 06acc91..1de215e 100644 --- a/serversession-frontend-snap/src/Web/ServerSession/Frontend/Snap.hs +++ b/serversession-frontend-snap/src/Web/ServerSession/Frontend/Snap.hs @@ -3,6 +3,7 @@ module Web.ServerSession.Frontend.Snap ( -- * Using server-side sessions initServerSessionManager , simpleServerSessionManager + , SnapSession(..) -- * Invalidating session IDs , forceInvalidate , ForceInvalidate(..) diff --git a/serversession-frontend-snap/src/Web/ServerSession/Frontend/Snap/Internal.hs b/serversession-frontend-snap/src/Web/ServerSession/Frontend/Snap/Internal.hs index 9065456..732b673 100644 --- a/serversession-frontend-snap/src/Web/ServerSession/Frontend/Snap/Internal.hs +++ b/serversession-frontend-snap/src/Web/ServerSession/Frontend/Snap/Internal.hs @@ -3,6 +3,7 @@ module Web.ServerSession.Frontend.Snap.Internal ( initServerSessionManager , simpleServerSessionManager + , SnapSession(..) , ServerSessionManager(..) , currentSessionMap , modifyCurrentSession @@ -33,7 +34,10 @@ import qualified Snap.Snaplet.Session.SessionManager as S -- | Create a new 'ServerSessionManager' using the given 'State'. -initServerSessionManager :: Storage s => IO (State s) -> S.SnapletInit b S.SessionManager +initServerSessionManager + :: (Storage sto, SnapSession (SessionData sto)) + => IO (State sto) + -> S.SnapletInit b S.SessionManager initServerSessionManager mkState = S.makeSnaplet "ServerSession" "Snaplet providing sessions via server-side storage." @@ -51,17 +55,67 @@ initServerSessionManager mkState = -- | Simplified version of 'initServerSessionManager', sufficient -- for most needs. -simpleServerSessionManager :: Storage s => IO s -> (State s -> State s) -> S.SnapletInit b S.SessionManager +simpleServerSessionManager + :: (Storage sto, SessionData sto ~ SessionMap) + => IO sto + -> (State sto -> State sto) + -> S.SnapletInit b S.SessionManager simpleServerSessionManager mkStorage opts = initServerSessionManager (fmap opts . createState =<< mkStorage) +---------------------------------------------------------------------- + + +-- | Class for data types that implement the operations Snap +-- expects sessions to support. +class IsSessionData sess => SnapSession sess where + ssInsert :: Text -> Text -> sess -> sess + ssLookup :: Text -> sess -> Maybe Text + ssDelete :: Text -> sess -> sess + ssToList :: sess -> [(Text, Text)] + + ssInsertCsrf :: Text -> sess -> sess + ssLookupCsrf :: sess -> Maybe Text + + ssForceInvalidate :: ForceInvalidate -> sess -> sess + + +-- | Uses 'csrfKey'. +instance SnapSession SessionMap where + ssInsert key val = onSM (M.insert key (TE.encodeUtf8 val)) + ssLookup key = fmap TE.decodeUtf8 . M.lookup key . unSessionMap + ssDelete key = onSM (M.delete key) + ssToList = + -- Remove the CSRF key from the list as the current + -- clientsession backend doesn't return it. + fmap (second TE.decodeUtf8) . + M.toList . + M.delete csrfKey . + unSessionMap + + ssInsertCsrf = ssInsert csrfKey + ssLookupCsrf = ssLookup csrfKey + + ssForceInvalidate force = onSM (M.insert forceInvalidateKey (B8.pack $ show force)) + + +-- | Apply a function to a 'SessionMap'. +onSM + :: (M.Map Text ByteString -> M.Map Text ByteString) + -> (SessionMap -> SessionMap) +onSM f = SessionMap . f . unSessionMap + + +---------------------------------------------------------------------- + + -- | A 'S.ISessionManager' using server-side sessions. -data ServerSessionManager s = +data ServerSessionManager sto = ServerSessionManager - { currentSession :: Maybe (SessionMap, SaveSessionToken) + { currentSession :: Maybe (SessionData sto, SaveSessionToken sto) -- ^ Field used for per-request caching of the session. - , state :: State s + , state :: State sto -- ^ The core @serversession@ state. , cookieName :: ByteString -- ^ Cache of the cookie name as bytestring. @@ -70,26 +124,32 @@ data ServerSessionManager s = } deriving (Typeable) -instance Storage s => S.ISessionManager (ServerSessionManager s) where +instance ( Storage sto + , SnapSession (SessionData sto) + ) => S.ISessionManager (ServerSessionManager sto) where + load ssm@ServerSessionManager { currentSession = Just _ } = + -- Don't do anything if already loaded. Yeah, I know this is + -- strange, go figure. + return ssm load ssm = do -- Get session ID from cookie. mcookie <- S.getCookie (cookieName ssm) -- Load session from storage backend. - (sessionMap, saveSessionToken) <- + (data1, saveSessionToken) <- liftIO $ loadSession (state ssm) (S.cookieValue <$> mcookie) -- Add CSRF token if needed. - sessionMap' <- + data2 <- maybe - (flip (M.insert csrfKey) sessionMap <$> N.nonce128url (nonceGen ssm)) - (const $ return sessionMap) - (M.lookup csrfKey sessionMap) + (flip ssInsertCsrf data1 <$> N.nonce128urlT (nonceGen ssm)) + (const $ return data1) + (ssLookupCsrf data1) -- Good to go! - return ssm { currentSession = Just (sessionMap', saveSessionToken) } + return ssm { currentSession = Just (data2, saveSessionToken) } commit ssm = do -- Save session data to storage backend and set the cookie. - let Just (sessionMap, saveSessionToken) = currentSession ssm - msession <- liftIO $ saveSession (state ssm) saveSessionToken sessionMap + let Just (data_, saveSessionToken) = currentSession ssm + msession <- liftIO $ saveSession (state ssm) saveSessionToken data_ S.modifyResponse $ S.addResponseCookie $ maybe (deleteCookie (state ssm) (cookieName ssm)) @@ -100,60 +160,62 @@ instance Storage s => S.ISessionManager (ServerSessionManager s) where -- Reset has no defined semantics. We invalidate the session -- and clear its variables, which seems to be what the -- current clientsession backend from the snap package does. - csrfToken <- N.nonce128url (nonceGen ssm) - let newSession = M.fromList [ (forceInvalidateKey, B8.pack $ show CurrentSessionId) - , (csrfKey, csrfToken) ] + csrfToken <- N.nonce128urlT (nonceGen ssm) + let newSession = + ssInsertCsrf csrfToken $ + ssForceInvalidate CurrentSessionId $ + emptySession return $ modifyCurrentSession (const newSession) ssm touch = -- We always touch the session (if commit is called). id - insert key value = modifyCurrentSession (M.insert key (TE.encodeUtf8 value)) + insert key value = modifyCurrentSession (ssInsert key value) lookup key = -- Decoding will always succeed if the session is used only -- from snap. - fmap TE.decodeUtf8 . M.lookup key . currentSessionMap "lookup" + ssLookup key . currentSessionMap "lookup" - delete key = modifyCurrentSession (M.delete key) + delete key = modifyCurrentSession (ssDelete key) csrf = -- Guaranteed to succeed since both load and reset add a -- csrfKey to the session map. fromMaybe (error "serversession-frontend-snap/csrf: never here") . - S.lookup csrfKey + ssLookupCsrf . currentSessionMap "csrf" - toList = - -- Remove the CSRF key from the list as the current - -- clientsession backend doesn't return it. - fmap (second TE.decodeUtf8) . - M.toList . - M.delete csrfKey . - currentSessionMap "toList" + toList = ssToList . currentSessionMap "toList" --- | Get the current 'SessionMap' from 'currentSession' and +-- | Get the current 'SessionData' from 'currentSession' and -- unwrap its @Just@. If it's @Nothing@, @error@ is called. We -- expect 'load' to be called before any other 'ISessionManager' -- method. -currentSessionMap :: String -> ServerSessionManager s -> SessionMap +currentSessionMap :: String -> ServerSessionManager sto -> SessionData sto currentSessionMap fn ssm = maybe (error err) fst (currentSession ssm) where err = "serversession-frontend-snap/" ++ fn ++ ": currentSession is Nothing, did you call 'load'?" -- | Modify the current session in any way. -modifyCurrentSession :: (SessionMap -> SessionMap) -> ServerSessionManager s -> ServerSessionManager s +modifyCurrentSession + :: (SessionData sto -> SessionData sto) + -> ServerSessionManager sto + -> ServerSessionManager sto modifyCurrentSession f ssm = ssm { currentSession = fmap (first f) (currentSession ssm) } +---------------------------------------------------------------------- + + -- | Create a cookie for the given session. -- -- The cookie expiration is set via 'nextExpires'. Note that -- this is just an optimization, as the expiration is checked on -- the server-side as well. -createCookie :: State s -> ByteString -> Session -> S.Cookie +createCookie :: State sto -> ByteString -> Session sess -> S.Cookie createCookie st cookieNameBS session = -- Generate a cookie with the final session ID. S.Cookie @@ -176,7 +238,7 @@ createCookie st cookieNameBS session = -- * If the user had a session cookie that was invalidated, -- this will remove the invalid cookie from the client. -- the server-side as well. -deleteCookie :: State s -> ByteString -> S.Cookie +deleteCookie :: State sto -> ByteString -> S.Cookie deleteCookie st cookieNameBS = S.Cookie { S.cookieName = cookieNameBS diff --git a/serversession-frontend-wai/serversession-frontend-wai.cabal b/serversession-frontend-wai/serversession-frontend-wai.cabal index d56a5d9..c03d55c 100644 --- a/serversession-frontend-wai/serversession-frontend-wai.cabal +++ b/serversession-frontend-wai/serversession-frontend-wai.cabal @@ -34,7 +34,9 @@ library Web.ServerSession.Frontend.Wai Web.ServerSession.Frontend.Wai.Internal extensions: + FlexibleContexts OverloadedStrings + TypeFamilies ghc-options: -Wall source-repository head diff --git a/serversession-frontend-wai/src/Web/ServerSession/Frontend/Wai.hs b/serversession-frontend-wai/src/Web/ServerSession/Frontend/Wai.hs index d177814..5161142 100644 --- a/serversession-frontend-wai/src/Web/ServerSession/Frontend/Wai.hs +++ b/serversession-frontend-wai/src/Web/ServerSession/Frontend/Wai.hs @@ -17,6 +17,7 @@ module Web.ServerSession.Frontend.Wai -- * Flexible interface , sessionStore , createCookieTemplate + , KeyValue(..) -- * State configuration , setCookieName , setAuthKey diff --git a/serversession-frontend-wai/src/Web/ServerSession/Frontend/Wai/Internal.hs b/serversession-frontend-wai/src/Web/ServerSession/Frontend/Wai/Internal.hs index 1e245d5..ae51708 100644 --- a/serversession-frontend-wai/src/Web/ServerSession/Frontend/Wai/Internal.hs +++ b/serversession-frontend-wai/src/Web/ServerSession/Frontend/Wai/Internal.hs @@ -4,6 +4,7 @@ module Web.ServerSession.Frontend.Wai.Internal ( withServerSession , sessionStore , mkSession + , KeyValue(..) , createCookieTemplate , calculateMaxAge , forceInvalidate @@ -34,10 +35,10 @@ import qualified Web.Cookie as C -- that uses 'WS.withSession', 'createState', 'sessionStore', -- 'getCookieName' and 'createCookieTemplate'. withServerSession - :: (MonadIO m, MonadIO n, Storage s) + :: (MonadIO m, MonadIO n, Storage sto, SessionData sto ~ SessionMap) => V.Key (WS.Session m Text ByteString) -- ^ 'V.Vault' key to use when passing the session through. - -> (State s -> State s) -- ^ Set any options on the @serversession@ state. - -> s -- ^ Storage backend. + -> (State sto -> State sto) -- ^ Set any options on the @serversession@ state. + -> sto -- ^ Storage backend. -> n W.Middleware withServerSession key opts storage = liftIO $ do st <- opts <$> createState storage @@ -56,32 +57,55 @@ withServerSession key opts storage = liftIO $ do -- return an empty @ByteString@ when the empty session was not -- saved. sessionStore - :: (MonadIO m, Storage s) - => State s -- ^ @serversession@ state, incl. storage backend. - -> WS.SessionStore m Text ByteString -- ^ @wai-session@ session store. + :: (MonadIO m, Storage sto, KeyValue (SessionData sto)) + => State sto -- ^ @serversession@ state, incl. storage backend. + -> WS.SessionStore m (Key (SessionData sto)) (Value (SessionData sto)) + -- ^ @wai-session@ session store. sessionStore state = \mcookieVal -> do - (sessionMap, saveSessionToken) <- loadSession state mcookieVal - sessionRef <- I.newIORef sessionMap + (data1, saveSessionToken) <- loadSession state mcookieVal + sessionRef <- I.newIORef data1 let save = do - sessionMap' <- I.atomicModifyIORef' sessionRef $ \a -> (a, a) - msession <- saveSession state saveSessionToken sessionMap' + data2 <- I.atomicModifyIORef' sessionRef $ \a -> (a, a) + msession <- saveSession state saveSessionToken data2 return $ maybe "" (TE.encodeUtf8 . toPathPiece . sessionKey) msession return (mkSession sessionRef, save) -- | Build a 'WS.Session' from an 'I.IORef' containing the -- session data. -mkSession :: MonadIO m => I.IORef SessionMap -> WS.Session m Text ByteString +mkSession :: (MonadIO m, KeyValue sess) => I.IORef sess -> WS.Session m (Key sess) (Value sess) mkSession sessionRef = -- We need to use atomicModifyIORef instead of readIORef -- because latter may be reordered (cf. "Memory Model" on -- Data.IORef's documentation). - ( \k -> M.lookup k <$> liftIO (I.atomicModifyIORef' sessionRef $ \a -> (a, a)) - , \k v -> liftIO (I.atomicModifyIORef' sessionRef $ flip (,) () . M.insert k v) + ( \k -> kvLookup k <$> liftIO (I.atomicModifyIORef' sessionRef $ \a -> (a, a)) + , \k v -> liftIO (I.atomicModifyIORef' sessionRef $ flip (,) () . kvInsert k v) ) +---------------------------------------------------------------------- + + +-- | Class for session data types that can be used as key-value +-- stores. +class IsSessionData sess => KeyValue sess where + type Key sess :: * + type Value sess :: * + kvLookup :: Key sess -> sess -> Maybe (Value sess) + kvInsert :: Key sess -> Value sess -> sess -> sess + + +instance KeyValue SessionMap where + type Key SessionMap = Text + type Value SessionMap = ByteString + kvLookup k = M.lookup k . unSessionMap + kvInsert k v (SessionMap m) = SessionMap (M.insert k v m) + + +---------------------------------------------------------------------- + + -- | Create a cookie template given a state. -- -- Since we don't have access to the 'Session', we can't fill the @@ -94,7 +118,7 @@ mkSession sessionRef = -- Instead, we fill only the @Max-age@ field. It works fine for -- modern browsers, but many don't support it and will treat the -- cookie as non-persistent (notably IE 6, 7 and 8). -createCookieTemplate :: State s -> C.SetCookie +createCookieTemplate :: State sto -> C.SetCookie createCookieTemplate state = -- Generate a cookie with the final session ID. def @@ -115,7 +139,7 @@ createCookieTemplate state = -- * If no timeout is defined, the result is 10 years. -- -- * Otherwise, the max age is set as the maximum timeout. -calculateMaxAge :: State s -> Maybe TI.DiffTime +calculateMaxAge :: State sto -> Maybe TI.DiffTime calculateMaxAge st = do guard (persistentCookies st) return $ maybe (60*60*24*3652) realToFrac diff --git a/serversession-frontend-yesod/serversession-frontend-yesod.cabal b/serversession-frontend-yesod/serversession-frontend-yesod.cabal index 960315b..c22c26e 100644 --- a/serversession-frontend-yesod/serversession-frontend-yesod.cabal +++ b/serversession-frontend-yesod/serversession-frontend-yesod.cabal @@ -18,6 +18,7 @@ library build-depends: base == 4.* , bytestring + , containers , cookie >= 0.4 , data-default , path-pieces @@ -32,7 +33,9 @@ library Web.ServerSession.Frontend.Yesod Web.ServerSession.Frontend.Yesod.Internal extensions: + FlexibleContexts OverloadedStrings + TypeFamilies ghc-options: -Wall source-repository head diff --git a/serversession-frontend-yesod/src/Web/ServerSession/Frontend/Yesod.hs b/serversession-frontend-yesod/src/Web/ServerSession/Frontend/Yesod.hs index 887afdb..fe0d859 100644 --- a/serversession-frontend-yesod/src/Web/ServerSession/Frontend/Yesod.hs +++ b/serversession-frontend-yesod/src/Web/ServerSession/Frontend/Yesod.hs @@ -1,8 +1,19 @@ -- | Yesod server-side session support. +-- +-- This package implements an Yesod @SessionBackend@, so it's a +-- drop-in replacement for the default @clientsession@. +-- +-- Unfortunately, Yesod currently provides no way of accessing +-- the session other than via its own functions. If you want to +-- use a custom data type as your session data (instead of the +-- default @SessionMap@), it will have to implement +-- 'IsSessionMap' and you'll have to continue using Yesod's +-- session interface. module Web.ServerSession.Frontend.Yesod ( -- * Using server-side sessions simpleBackend , backend + , IsSessionMap(..) -- * Invalidating session IDs , forceInvalidate , ForceInvalidate(..) diff --git a/serversession-frontend-yesod/src/Web/ServerSession/Frontend/Yesod/Internal.hs b/serversession-frontend-yesod/src/Web/ServerSession/Frontend/Yesod/Internal.hs index 9221b4d..cd5144e 100644 --- a/serversession-frontend-yesod/src/Web/ServerSession/Frontend/Yesod/Internal.hs +++ b/serversession-frontend-yesod/src/Web/ServerSession/Frontend/Yesod/Internal.hs @@ -3,6 +3,7 @@ module Web.ServerSession.Frontend.Yesod.Internal ( simpleBackend , backend + , IsSessionMap(..) , createCookie , findSessionId , forceInvalidate @@ -12,6 +13,7 @@ import Control.Monad (guard) import Control.Monad.IO.Class (MonadIO) import Data.ByteString (ByteString) import Data.Default (def) +import Data.Text (Text) import Web.PathPieces (toPathPiece) import Web.ServerSession.Core import Yesod.Core (MonadHandler) @@ -19,6 +21,7 @@ import Yesod.Core.Handler (setSessionBS) import Yesod.Core.Types (Header(AddCookie), SessionBackend(..)) import qualified Data.ByteString.Char8 as B8 +import qualified Data.Map as M import qualified Data.Text.Encoding as TE import qualified Data.Time as TI import qualified Network.Wai as W @@ -53,42 +56,70 @@ import qualified Web.Cookie as C -- . setAbsoluteTimeout (Just $ 60*60*24) -- . setSecureCookies True -- @ +-- +-- This is a simple version of 'backend' specialized for using +-- 'SessionMap' as 'SessionData'. If you want to use a different +-- session data type, please use 'backend' directly (tip: take a +-- peek at this function's source). simpleBackend - :: (MonadIO m, Storage s) - => (State s -> State s) -- ^ Set any options on the @serversession@ state. - -> s -- ^ Storage backend. + :: (MonadIO m, Storage sto, SessionData sto ~ SessionMap) + => (State sto -> State sto) -- ^ Set any options on the @serversession@ state. + -> sto -- ^ Storage backend. -> m (Maybe SessionBackend) -- ^ Yesod session backend (always @Just@). simpleBackend opts s = return . Just . backend . opts =<< createState s -- | Construct the server-side session backend using the given --- state. +-- state. This is a generalized version of 'simpleBackend'. +-- +-- In order to use the Yesod frontend, you 'SessionData' needs to +-- implement 'IsSessionMap'. backend - :: Storage s - => State s -- ^ @serversession@ state, incl. storage backend. + :: (Storage sto, IsSessionMap (SessionData sto)) + => State sto -- ^ @serversession@ state, incl. storage backend. -> SessionBackend -- ^ Yesod session backend. -backend state = - SessionBackend { - sbLoadSession = \req -> do +backend state = SessionBackend { sbLoadSession = load } + where + load req = do let rawSessionId = findSessionId cookieNameBS req - (sessionMap, saveSessionToken) <- loadSession state rawSessionId + (data_, saveSessionToken) <- loadSession state rawSessionId let save = fmap ((:[]) . maybe (deleteCookie state cookieNameBS) (createCookie state cookieNameBS)) . - saveSession state saveSessionToken - return (sessionMap, save) - } - where + saveSession state saveSessionToken . + fromSessionMap + return (toSessionMap data_, save) + cookieNameBS = TE.encodeUtf8 $ getCookieName state +---------------------------------------------------------------------- + + +-- | Class for session data types meant to be used with the Yesod +-- frontend. The only session interface Yesod provides is via +-- session variables, so your data type needs to be convertible +-- from/to a 'M.Map' of 'Text' to 'ByteString'. +class IsSessionMap sess where + toSessionMap :: sess -> M.Map Text ByteString + fromSessionMap :: M.Map Text ByteString -> sess + + +instance IsSessionMap SessionMap where + toSessionMap = unSessionMap + fromSessionMap = SessionMap + + +---------------------------------------------------------------------- + + -- | Create a cookie for the given session. -- -- The cookie expiration is set via 'nextExpires'. Note that -- this is just an optimization, as the expiration is checked on -- the server-side as well. -createCookie :: State s -> ByteString -> Session -> Header +createCookie :: State sto -> ByteString -> Session sess -> Header createCookie state cookieNameBS session = -- Generate a cookie with the final session ID. AddCookie def @@ -110,7 +141,7 @@ createCookie state cookieNameBS session = -- -- * If the user had a session cookie that was invalidated, -- this will remove the invalid cookie from the client. -deleteCookie :: State s -> ByteString -> Header +deleteCookie :: State sto -> ByteString -> Header deleteCookie state cookieNameBS = AddCookie def { C.setCookieName = cookieNameBS diff --git a/serversession/serversession.cabal b/serversession/serversession.cabal index 74a7d6b..8c9618e 100644 --- a/serversession/serversession.cabal +++ b/serversession/serversession.cabal @@ -37,7 +37,9 @@ library OverloadedStrings RecordWildCards ScopedTypeVariables + StandaloneDeriving TypeFamilies + UndecidableInstances ghc-options: -Wall @@ -53,9 +55,12 @@ test-suite tests , serversession extensions: DeriveDataTypeable + FlexibleContexts OverloadedStrings + StandaloneDeriving TupleSections TypeFamilies + UndecidableInstances main-is: Main.hs ghc-options: -Wall -threaded -with-rtsopts=-N diff --git a/serversession/src/Web/ServerSession/Core.hs b/serversession/src/Web/ServerSession/Core.hs index 34b2253..40a3c9f 100644 --- a/serversession/src/Web/ServerSession/Core.hs +++ b/serversession/src/Web/ServerSession/Core.hs @@ -6,9 +6,11 @@ module Web.ServerSession.Core , Session(..) , Storage(..) , StorageException(..) + , IsSessionData(..) + , DecomposedSession(..) -- * For serversession frontends - , SessionMap + , SessionMap(..) , State , createState , getCookieName diff --git a/serversession/src/Web/ServerSession/Core/Internal.hs b/serversession/src/Web/ServerSession/Core/Internal.hs index 6f5516c..c03a870 100644 --- a/serversession/src/Web/ServerSession/Core/Internal.hs +++ b/serversession/src/Web/ServerSession/Core/Internal.hs @@ -1,13 +1,21 @@ -- | Internal module exposing the guts of the package. Use at -- your own risk. No API stability guarantees apply. +-- +-- @UndecidableInstances@ is required in order to implement @Eq@, +-- @Ord@, @Show@, etc. on data types that have @Decomposed@ +-- fields, and should be fairly safe. module Web.ServerSession.Core.Internal ( SessionId(..) , checkSessionId , generateSessionId - , SessionMap , AuthId , Session(..) + , SessionMap(..) + + , IsSessionData(..) + , DecomposedSession(..) + , Storage(..) , StorageException(..) @@ -32,10 +40,7 @@ module Web.ServerSession.Core.Internal , saveSession , SaveSessionToken(..) , invalidateIfNeeded - , DecomposedSession(..) - , decomposeSession , saveSessionOnDb - , toSessionMap , forceInvalidateKey , ForceInvalidate(..) ) where @@ -64,7 +69,8 @@ import qualified Data.Text.Encoding as TE -- | The ID of a session. Always 18 bytes base64url-encoded as --- 24 characters. +-- 24 characters. The @sess@ type variable is a phantom type for +-- the session data type this session ID points to. -- -- Implementation notes: -- @@ -72,24 +78,24 @@ import qualified Data.Text.Encoding as TE -- -- * Use 'generateSessionId' for securely generating new -- session IDs. -newtype SessionId = S { unS :: Text } +newtype SessionId sess = S { unS :: Text } deriving (Eq, Ord, Show, Read, Typeable) -- | Sanity checks input on 'fromPathPiece' (untrusted input). -instance PathPiece SessionId where +instance PathPiece (SessionId sess) where toPathPiece = unS fromPathPiece = checkSessionId -instance A.FromJSON SessionId where +instance A.FromJSON (SessionId sess) where parseJSON = fmap S . A.parseJSON -instance A.ToJSON SessionId where +instance A.ToJSON (SessionId sess) where toJSON = A.toJSON . unS -- | (Internal) Check that the given text is a base64url-encoded -- representation of 18 bytes. -checkSessionId :: Text -> Maybe SessionId +checkSessionId :: Text -> Maybe (SessionId sess) checkSessionId text = do guard (T.length text == 24) let bs = TE.encodeUtf8 text @@ -99,23 +105,13 @@ checkSessionId text = do -- | Securely generate a new SessionId. -generateSessionId :: N.Generator -> IO SessionId +generateSessionId :: N.Generator -> IO (SessionId sess) generateSessionId = fmap S . N.nonce128urlT ---------------------------------------------------------------------- --- | A session map. --- --- This is the representation of a session used by the --- @serversession@ family of packages, transferring data between --- this core package and frontend packages. Serversession --- storage backend packages should use 'Session'. End users --- should use their web framework's support for sessions. -type SessionMap = M.Map Text ByteString - - -- | Value of the 'authKey' session key. type AuthId = ByteString @@ -124,49 +120,179 @@ type AuthId = ByteString -- -- This representation is used by the @serversession@ family of -- packages, transferring data between this core package and --- storage backend packages. Serversession frontend packages --- should use 'SessionMap'. End users should use their web --- framework's support for sessions. -data Session = +-- storage backend packages. The @sess@ type variable describes +-- the session data type. +data Session sess = Session - { sessionKey :: SessionId + { sessionKey :: SessionId sess -- ^ Session ID, primary key. , sessionAuthId :: Maybe AuthId -- ^ Value of 'authKey' session key, separate from the rest. - , sessionData :: SessionMap + , sessionData :: Decomposed sess -- ^ Rest of the session data. , sessionCreatedAt :: UTCTime -- ^ When this session was created. , sessionAccessedAt :: UTCTime -- ^ When this session was last accessed. - } deriving (Eq, Ord, Show, Typeable) + } deriving (Typeable) + +deriving instance Eq (Decomposed sess) => Eq (Session sess) +deriving instance Ord (Decomposed sess) => Ord (Session sess) +deriving instance Show (Decomposed sess) => Show (Session sess) --- | A storage backend for server-side sessions. -class MonadIO (TransactionM s) => Storage s where +-- | A @newtype@ for a common session map. +-- +-- This is a common representation of a session. Although +-- @serversession@ has generalized session data types, you can +-- use this one if you don't want to worry about it. We strive +-- to support this session data type on all frontends and storage +-- backends. +newtype SessionMap = + SessionMap { unSessionMap :: M.Map Text ByteString } + deriving (Eq, Ord, Show, Read, Typeable) + + +---------------------------------------------------------------------- + + +-- | Class for data types to be used as session data +-- (cf. 'sessionData', 'SessionData'). +-- +-- The @Show@ constrain is needed for 'StorageException'. +class ( Show (Decomposed sess) + , Typeable (Decomposed sess) + , Typeable sess + ) => IsSessionData sess where + -- | The type of the session data after being decomposed. This + -- may be the same as @sess@. + type Decomposed sess :: * + + -- | Empty session data. + emptySession :: sess + + -- | Decompose session data into: + -- + -- * The auth ID of the logged in user (cf. 'setAuthKey', + -- 'dsAuthId'). + -- + -- * If the session is being forced to be invalidated + -- (cf. 'forceInvalidateKey', 'ForceInvalidate'). + -- + -- * The rest of the session data (cf. 'Decomposed'). + decomposeSession + :: Text -- ^ The auth key (cf. 'setAuthKey'). + -> sess -- ^ Session data to be decomposed. + -> DecomposedSession sess -- ^ Decomposed session data. + + -- | Recompose a decomposed session again into a proper @sess@. + recomposeSession + :: Text -- ^ The auth key (cf. 'setAuthKey'). + -> Maybe AuthId -- ^ The @AuthId@, if any. + -> Decomposed sess -- ^ Decomposed session data to be recomposed. + -> sess -- ^ Recomposed session data. + + -- | Returns @True@ when both session datas are to be + -- considered the same. + -- + -- This is used to optimize storage calls + -- (cf. 'setTimeoutResolution'). Always returning @False@ will + -- disable the optimization but won't have any other adverse + -- effects. + -- + -- For data types implementing 'Eq', this is usually a good + -- implementation: + -- + -- @ + -- isSameDecomposed _ = (==) + -- @ + isSameDecomposed :: proxy sess -> Decomposed sess -> Decomposed sess -> Bool + + -- | Returns @True@ if the decomposed session data is to be + -- considered @empty@. + -- + -- This is used to avoid storing empty session data if at all + -- possible. Always returning @False@ will disable the + -- optimization but won't have any other adverse effects. + isDecomposedEmpty :: proxy sess -> Decomposed sess -> Bool + + +-- | A 'SessionMap' decomposes into a 'SessionMap' minus the keys +-- that were removed. The auth key is added back when +-- recomposing. +instance IsSessionData SessionMap where + type Decomposed SessionMap = SessionMap + + emptySession = SessionMap M.empty + + isSameDecomposed _ = (==) + + decomposeSession authKey_ (SessionMap sm1) = + let (authId, sm2) = M.updateLookupWithKey (\_ _ -> Nothing) authKey_ sm1 + (force, sm3) = M.updateLookupWithKey (\_ _ -> Nothing) forceInvalidateKey sm2 + in DecomposedSession + { dsAuthId = authId + , dsForceInvalidate = maybe DoNotForceInvalidate (read . B8.unpack) force + , dsDecomposed = SessionMap sm3 } + + recomposeSession authKey_ mauthId (SessionMap sm) = + SessionMap $ maybe id (M.insert authKey_) mauthId sm + + isDecomposedEmpty _ = M.null . unSessionMap + + +-- | A session data type @sess@ with its special variables taken apart. +data DecomposedSession sess = + DecomposedSession + { dsAuthId :: !(Maybe ByteString) + , dsForceInvalidate :: !ForceInvalidate + , dsDecomposed :: !(Decomposed sess) + } deriving (Typeable) + +deriving instance Eq (Decomposed sess) => Eq (DecomposedSession sess) +deriving instance Ord (Decomposed sess) => Ord (DecomposedSession sess) +deriving instance Show (Decomposed sess) => Show (DecomposedSession sess) + + +---------------------------------------------------------------------- + + +-- | A storage backend @sto@ for server-side sessions. The +-- @sess@ session data type and\/or its 'Decomposed' version may +-- be constrained depending on the storage backend capabilities. +class ( Typeable sto + , MonadIO (TransactionM sto) + , IsSessionData (SessionData sto) + ) => Storage sto where + -- | The session data type used by this storage. + type SessionData sto :: * + -- | Monad where transactions happen for this backend. -- We do not require transactions to be ACID. - type TransactionM s :: * -> * + type TransactionM sto :: * -> * -- | Run a transaction on the IO monad. - runTransactionM :: s -> TransactionM s a -> IO a + runTransactionM :: sto -> TransactionM sto a -> IO a -- | Get the session for the given session ID. Returns -- @Nothing@ if the session is not found. - getSession :: s -> SessionId -> TransactionM s (Maybe Session) + getSession + :: sto + -> SessionId (SessionData sto) + -> TransactionM sto (Maybe (Session (SessionData sto))) -- | Delete the session with given session ID. Does not do -- anything if the session is not found. - deleteSession :: s -> SessionId -> TransactionM s () + deleteSession :: sto -> SessionId (SessionData sto) -> TransactionM sto () -- | Delete all sessions of the given auth ID. Does not do -- anything if there are no sessions of the given auth ID. - deleteAllSessionsOfAuthId :: s -> AuthId -> TransactionM s () + deleteAllSessionsOfAuthId :: sto -> AuthId -> TransactionM sto () -- | Insert a new session. Throws 'SessionAlreadyExists' if -- there already exists a session with the same session ID (we -- only call this method after generating a fresh session ID). - insertSession :: s -> Session -> TransactionM s () + insertSession :: sto -> Session (SessionData sto) -> TransactionM sto () -- | Replace the contents of a session. Throws -- 'SessionDoesNotExist' if there is no session with the given @@ -211,29 +337,32 @@ class MonadIO (TransactionM s) => Storage s where -- Most of the time this discussion does not matter. -- Invalidations usually occur at times where only one request -- is flying. - replaceSession :: s -> Session -> TransactionM s () + replaceSession :: sto -> Session (SessionData sto) -> TransactionM sto () -- | Common exceptions that may be thrown by any storage. -data StorageException = +data StorageException sto = -- | Exception thrown by 'insertSession' whenever a session -- with same ID already exists. SessionAlreadyExists - { seExistingSession :: Session - , seNewSession :: Session } + { seExistingSession :: Session (SessionData sto) + , seNewSession :: Session (SessionData sto) } -- | Exception thrown by 'replaceSession' whenever trying to -- replace a session that is not present on the storage. | SessionDoesNotExist - { seNewSession :: Session } - deriving (Show, Typeable) + { seNewSession :: Session (SessionData sto) } + deriving (Typeable) -instance E.Exception StorageException where +deriving instance Eq (Decomposed (SessionData sto)) => Eq (StorageException sto) +deriving instance Ord (Decomposed (SessionData sto)) => Ord (StorageException sto) +deriving instance Show (Decomposed (SessionData sto)) => Show (StorageException sto) + +instance Storage sto => E.Exception (StorageException sto) where ---------------------------------------------------------------------- - -- TODO: delete expired sessions. -- | The server-side session backend needs to maintain some state @@ -256,10 +385,10 @@ instance E.Exception StorageException where -- and/or secure ('setSecureCookies'). -- -- Create a new 'State' using 'createState'. -data State s = +data State sto = State { generator :: !N.Generator - , storage :: !s + , storage :: !sto , cookieName :: !Text , authKey :: !Text , idleTimeout :: !(Maybe NominalDiffTime) @@ -273,7 +402,7 @@ data State s = -- | Create a new 'State' for the server-side session backend -- using the given storage backend. -createState :: MonadIO m => s -> m (State s) +createState :: MonadIO m => sto -> m (State sto) createState sto = do gen <- N.new return State @@ -294,13 +423,21 @@ createState sto = do -- Defaults to \"JSESSIONID\", which is a generic cookie name -- used by many frameworks thus making it harder to fingerprint -- this implementation. -setCookieName :: Text -> State s -> State s +setCookieName :: Text -> State sto -> State sto setCookieName val state = state { cookieName = val } -- | Set the name of the session variable that keeps track of the --- logged user. Defaults to \"_ID\" (used by @yesod-auth@). -setAuthKey :: Text -> State s -> State s +-- logged user. +-- +-- This setting is used by session data types that are +-- @Map@-alike, using a @lookup@ function. However, the +-- 'IsSessionData' instance of a session data type may choose not +-- to use it. For example, if you implemented a custom data +-- type, you could return the @AuthId@ without needing a lookup. +-- +-- Defaults to \"_ID\" (used by @yesod-auth@). +setAuthKey :: Text -> State sto -> State sto setAuthKey val state = state { authKey = val } @@ -318,7 +455,7 @@ setAuthKey val state = state { authKey = val } -- () -- -- Defaults to 7 days. -setIdleTimeout :: Maybe NominalDiffTime -> State s -> State s +setIdleTimeout :: Maybe NominalDiffTime -> State sto -> State sto setIdleTimeout (Just d) _ | d <= 0 = error "serversession/setIdleTimeout: Timeout should be positive." setIdleTimeout val state = state { idleTimeout = val } @@ -338,7 +475,7 @@ setIdleTimeout val state = state { idleTimeout = val } -- () -- -- Defaults to 60 days. -setAbsoluteTimeout :: Maybe NominalDiffTime -> State s -> State s +setAbsoluteTimeout :: Maybe NominalDiffTime -> State sto -> State sto setAbsoluteTimeout (Just d) _ | d <= 0 = error "serversession/setAbsoluteTimeout: Timeout should be positive." setAbsoluteTimeout val state = state { absoluteTimeout = val } @@ -368,7 +505,7 @@ setAbsoluteTimeout val state = state { absoluteTimeout = val } -- becomes disabled and the session will always be updated. -- -- Defaults to 10 minutes. -setTimeoutResolution :: Maybe NominalDiffTime -> State s -> State s +setTimeoutResolution :: Maybe NominalDiffTime -> State sto -> State sto setTimeoutResolution (Just d) _ | d <= 0 = error "serversession/setTimeoutResolution: Resolution should be positive." setTimeoutResolution val state = state { timeoutResolution = val } @@ -382,7 +519,7 @@ setTimeoutResolution val state = state { timeoutResolution = val } -- cookie is set to expire in 10 years. -- -- Defaults to @True@. -setPersistentCookies :: Bool -> State s -> State s +setPersistentCookies :: Bool -> State sto -> State sto setPersistentCookies val state = state { persistentCookies = val } @@ -393,7 +530,7 @@ setPersistentCookies val state = state { persistentCookies = val } -- It's highly recommended to set this attribute to @True@. -- -- Defaults to @True@. -setHttpOnlyCookies :: Bool -> State s -> State s +setHttpOnlyCookies :: Bool -> State sto -> State sto setHttpOnlyCookies val state = state { httpOnlyCookies = val } @@ -405,22 +542,22 @@ setHttpOnlyCookies val state = state { httpOnlyCookies = val } -- @False@. -- -- Defaults to @False@. -setSecureCookies :: Bool -> State s -> State s +setSecureCookies :: Bool -> State sto -> State sto setSecureCookies val state = state { secureCookies = val } -- | Cf. 'setCookieName'. -getCookieName :: State s -> Text +getCookieName :: State sto -> Text getCookieName = cookieName -- | Cf. 'setHttpOnlyCookies'. -getHttpOnlyCookies :: State s -> Bool +getHttpOnlyCookies :: State sto -> Bool getHttpOnlyCookies = httpOnlyCookies -- | Cf. 'setSecureCookies'. -getSecureCookies :: State s -> Bool +getSecureCookies :: State sto -> Bool getSecureCookies = secureCookies @@ -432,25 +569,33 @@ getSecureCookies = secureCookies -- -- Returns: -- --- * The 'SessionMap' to be used by the frontend as the current --- session's value. +-- * The session data @sess@ to be used by the frontend as the +-- current session's value. -- -- * Information to be passed back to 'saveSession' on the end -- of the request in order to save the session. -loadSession :: Storage s => State s -> Maybe ByteString -> IO (SessionMap, SaveSessionToken) +loadSession + :: Storage sto + => State sto + -> Maybe ByteString + -> IO (SessionData sto, SaveSessionToken sto) loadSession state mcookieVal = do now <- getCurrentTime let maybeInputId = mcookieVal >>= fromPathPiece . TE.decodeUtf8 get = runTransactionM (storage state) . getSession (storage state) checkedGet = fmap (>>= checkExpired now state) . get maybeInput <- maybe (return Nothing) checkedGet maybeInputId - let inputSessionMap = maybe M.empty (toSessionMap state) maybeInput - return (inputSessionMap, SaveSessionToken maybeInput now) + let inputData = + maybe + emptySession + (\s -> recomposeSession (authKey state) (sessionAuthId s) (sessionData s)) + maybeInput + return (inputData, SaveSessionToken maybeInput now) -- | Check if a session @s@ has expired. Returns the @Just s@ if -- not expired, or @Nothing@ if expired. -checkExpired :: UTCTime {-^ Now. -} -> State s -> Session -> Maybe Session +checkExpired :: UTCTime {-^ Now. -} -> State sto -> Session sess -> Maybe (Session sess) checkExpired now state session = let expired = maybe False (< now) (nextExpires state session) in guard (not expired) >> return session @@ -460,7 +605,7 @@ checkExpired now state session = -- will expire assuming that it sees no activity until then. -- Returns @Nothing@ iff the state does not have any expirations -- set to @Just@. -nextExpires :: State s -> Session -> Maybe UTCTime +nextExpires :: State sto -> Session sess -> Maybe UTCTime nextExpires State {..} Session {..} = let viaIdle = flip addUTCTime sessionAccessedAt <$> idleTimeout viaAbsolute = flip addUTCTime sessionCreatedAt <$> absoluteTimeout @@ -471,7 +616,7 @@ nextExpires State {..} Session {..} = -- | Calculate the date that should be used for the cookie's -- \"Expires\" field. -cookieExpires :: State s -> Session -> Maybe UTCTime +cookieExpires :: State sto -> Session sess -> Maybe UTCTime cookieExpires State {..} _ | not persistentCookies = Nothing cookieExpires state session = Just $ fromMaybe tenYearsFromNow $ nextExpires state session @@ -481,9 +626,13 @@ cookieExpires state session = -- | Opaque token containing the necessary information for -- 'saveSession' to save the session. -data SaveSessionToken = - SaveSessionToken (Maybe Session) UTCTime - deriving (Eq, Show, Typeable) +data SaveSessionToken sto = + SaveSessionToken (Maybe (Session (SessionData sto))) UTCTime + deriving (Typeable) + +deriving instance Eq (Decomposed (SessionData sto)) => Eq (SaveSessionToken sto) +deriving instance Ord (Decomposed (SessionData sto)) => Ord (SaveSessionToken sto) +deriving instance Show (Decomposed (SessionData sto)) => Show (SaveSessionToken sto) -- | Save the session on the storage backend. A @@ -496,12 +645,17 @@ data SaveSessionToken = -- and clear every other sesssion variable, then 'saveSession' -- will invalidate the older session but will avoid creating a -- new, empty one. -saveSession :: Storage s => State s -> SaveSessionToken -> SessionMap -> IO (Maybe Session) -saveSession state (SaveSessionToken maybeInput now) wholeOutputSessionMap = +saveSession + :: Storage sto + => State sto + -> SaveSessionToken sto + -> SessionData sto + -> IO (Maybe (Session (SessionData sto))) +saveSession state (SaveSessionToken maybeInput now) outputData = runTransactionM (storage state) $ do - let decomposedSessionMap = decomposeSession state wholeOutputSessionMap - newMaybeInput <- invalidateIfNeeded state maybeInput decomposedSessionMap - saveSessionOnDb state now newMaybeInput decomposedSessionMap + let outputDecomp = decomposeSession (authKey state) outputData + newMaybeInput <- invalidateIfNeeded state maybeInput outputDecomp + saveSessionOnDb state now newMaybeInput outputDecomp -- | Invalidates an old session ID if needed. Returns the @@ -512,11 +666,11 @@ saveSession state (SaveSessionToken maybeInput now) wholeOutputSessionMap = -- fixation attacks. We also invalidate when asked to via -- 'forceInvalidate'. invalidateIfNeeded - :: Storage s - => State s - -> Maybe Session - -> DecomposedSession - -> TransactionM s (Maybe Session) + :: Storage sto + => State sto + -> Maybe (Session (SessionData sto)) + -> DecomposedSession (SessionData sto) + -> TransactionM sto (Maybe (Session (SessionData sto))) invalidateIfNeeded state maybeInput DecomposedSession {..} = do -- Decide which action to take. -- "invalidateOthers implies invalidateCurrent" should be true below. @@ -531,26 +685,6 @@ invalidateIfNeeded state maybeInput DecomposedSession {..} = do return $ guard (not invalidateCurrent) >> maybeInput --- | A 'SessionMap' with its special variables taken apart. -data DecomposedSession = - DecomposedSession - { dsAuthId :: !(Maybe ByteString) - , dsForceInvalidate :: !ForceInvalidate - , dsSessionMap :: !SessionMap - } deriving (Eq, Show, Typeable) - - --- | Decompose a session (see 'DecomposedSession'). -decomposeSession :: State s -> SessionMap -> DecomposedSession -decomposeSession state sm1 = - let (authId, sm2) = M.updateLookupWithKey (\_ _ -> Nothing) (authKey state) sm1 - (force, sm3) = M.updateLookupWithKey (\_ _ -> Nothing) forceInvalidateKey sm2 - in DecomposedSession - { dsAuthId = authId - , dsForceInvalidate = maybe DoNotForceInvalidate (read . B8.unpack) force - , dsSessionMap = sm3 } - - -- | Save a session on the database. If an old session is -- supplied, it is replaced, otherwise a new session is -- generated. If the session is empty, it is not saved and @@ -558,24 +692,30 @@ decomposeSession state sm1 = -- is applied (cf. 'setTimeoutResolution'), the old session is -- returned and no update is made. saveSessionOnDb - :: Storage s - => State s - -> UTCTime -- ^ Now. - -> Maybe Session -- ^ The old session, if any. - -> DecomposedSession -- ^ The session data to be saved. - -> TransactionM s (Maybe Session) -- ^ Copy of saved session. + :: forall sto. Storage sto + => State sto + -> UTCTime -- ^ Now. + -> Maybe (Session (SessionData sto)) -- ^ The old session, if any. + -> DecomposedSession (SessionData sto) -- ^ The session data to be saved. + -> TransactionM sto (Maybe (Session (SessionData sto))) -- ^ Copy of saved session. saveSessionOnDb _ _ Nothing (DecomposedSession Nothing _ m) -- Return Nothing without doing anything whenever the session -- is empty (including auth ID) and there was no prior session. - | M.null m = return Nothing -saveSessionOnDb State { timeoutResolution = Just res } now (Just old) (DecomposedSession authId _ sessionMap) + | isDecomposedEmpty proxy m = return Nothing + where + proxy :: Maybe (SessionData sto) + proxy = Nothing +saveSessionOnDb State { timeoutResolution = Just res } now (Just old) (DecomposedSession authId _ newSession) -- If the data is the same and the old access time is within -- the timeout resolution, just return the old session without -- doing anything else. - | sessionData old == sessionMap && - sessionAuthId old == authId && + | sessionAuthId old == authId && + isSameDecomposed proxy (sessionData old) newSession && abs (diffUTCTime now (sessionAccessedAt old)) < res = return (Just old) + where + proxy :: Maybe (SessionData sto) + proxy = Nothing saveSessionOnDb state now maybeInput DecomposedSession {..} = do -- Generate properties if needed or take them from previous -- saved session. @@ -593,7 +733,7 @@ saveSessionOnDb state now maybeInput DecomposedSession {..} = do let session = Session { sessionKey = key , sessionAuthId = dsAuthId - , sessionData = dsSessionMap + , sessionData = dsDecomposed , sessionCreatedAt = createdAt , sessionAccessedAt = now } @@ -601,12 +741,6 @@ saveSessionOnDb state now maybeInput DecomposedSession {..} = do return (Just session) --- | Create a 'SessionMap' from a 'Session'. -toSessionMap :: State s -> Session -> SessionMap -toSessionMap state Session {..} = - maybe id (M.insert $ authKey state) sessionAuthId sessionData - - -- | The session key used to signal that the session ID should be -- invalidated. forceInvalidateKey :: Text diff --git a/serversession/src/Web/ServerSession/Core/StorageTests.hs b/serversession/src/Web/ServerSession/Core/StorageTests.hs index 0eeb940..e91ec66 100644 --- a/serversession/src/Web/ServerSession/Core/StorageTests.hs +++ b/serversession/src/Web/ServerSession/Core/StorageTests.hs @@ -18,7 +18,7 @@ import qualified Data.Text as T import qualified Data.Time as TI --- | Execute all storage tests. +-- | Execute all storage tests using 'SessionMap'. -- -- This function is meant to be used with @hspec@. However, we -- don't want to depend on @hspec@, so it takes the relevant @@ -42,8 +42,8 @@ import qualified Data.Time as TI -- sequentially in order to reduce the peak memory usage of the -- test suite. allStorageTests - :: forall m s. (Monad m, Storage s) - => s -- ^ Storage backend. + :: forall m sto. (Monad m, Storage sto, SessionData sto ~ SessionMap) + => sto -- ^ Storage backend. -> (String -> IO () -> m ()) -- ^ @hspec@'s it. -> (forall a. IO a -> m a) -- ^ @hspec@'s runIO. -> (m () -> m ()) -- ^ @hspec@'s parallel @@ -52,7 +52,7 @@ allStorageTests -> (forall a e. Exception e => IO a -> (e -> Bool) -> IO ()) -- ^ @hspec@'s shouldThrow. -> m () allStorageTests storage it runIO parallel _shouldBe shouldReturn shouldThrow = do - let run :: forall a. TransactionM s a -> IO a + let run :: forall a. TransactionM sto a -> IO a run = runTransactionM storage gen <- runIO N.new @@ -131,7 +131,8 @@ allStorageTests storage it runIO parallel _shouldBe shouldReturn shouldThrow = d run (insertSession storage s1) run (getSession storage sid) `shouldReturn` Just s1 run (insertSession storage s3) `shouldThrow` - (\(SessionAlreadyExists s1' s3') -> s1 == s1' && s3 == s3') + (\(SessionAlreadyExists s1' s3' :: StorageException sto) -> + s1 == s1' && s3 == s3') run (getSession storage sid) `shouldReturn` Just s1 -- replaceSession @@ -153,7 +154,8 @@ allStorageTests storage it runIO parallel _shouldBe shouldReturn shouldThrow = d s <- generateSession gen HasAuthId let sid = sessionKey s run (getSession storage sid) `shouldReturn` Nothing - run (replaceSession storage s) `shouldThrow` (\(SessionDoesNotExist s') -> s == s') + run (replaceSession storage s) `shouldThrow` + (\(SessionDoesNotExist s' :: StorageException sto) -> s == s') run (getSession storage sid) `shouldReturn` Nothing run (insertSession storage s) run (getSession storage sid) `shouldReturn` Just s @@ -169,11 +171,11 @@ allStorageTests storage it runIO parallel _shouldBe shouldReturn shouldThrow = d let session = Session { sessionKey = sid , sessionAuthId = Nothing - , sessionData = M.fromList vals + , sessionData = SessionMap $ M.fromList vals , sessionCreatedAt = now , sessionAccessedAt = now } - ver2 = session { sessionData = M.empty } + ver2 = session { sessionData = SessionMap M.empty } run (getSession storage sid) `shouldReturn` Nothing run (insertSession storage session) run (getSession storage sid) `shouldReturn` (Just session) @@ -204,7 +206,7 @@ generateAuthId = N.nonce128url -- | Generate a random session for our tests. -generateSession :: N.Generator -> HasAuthId -> IO Session +generateSession :: N.Generator -> HasAuthId -> IO (Session SessionMap) generateSession gen hasAuthId = do sid <- generateSessionId gen authId <- @@ -219,7 +221,7 @@ generateSession gen hasAuthId = do return Session { sessionKey = sid , sessionAuthId = authId - , sessionData = data_ + , sessionData = SessionMap data_ , sessionCreatedAt = TI.addUTCTime (-1000) now , sessionAccessedAt = now } diff --git a/serversession/tests/Main.hs b/serversession/tests/Main.hs index b7b860c..8bdffc6 100644 --- a/serversession/tests/Main.hs +++ b/serversession/tests/Main.hs @@ -58,7 +58,7 @@ main = hspec $ parallel $ do return $ fromPathPiece (toPathPiece sid) Q.=== Just sid it "does not accept as valid some example invalid session IDs" $ do - let parse = fromPathPiece :: T.Text -> Maybe SessionId + let parse = fromPathPiece :: T.Text -> Maybe (SessionId SessionMap) parse "" `shouldBe` Nothing parse "123456789-123456789-123" `shouldBe` Nothing parse "123456789-123456789-12345" `shouldBe` Nothing @@ -95,7 +95,7 @@ main = hspec $ parallel $ do let point1 = 0.1 {- second -} :: Double now <- TI.getCurrentTime abs (realToFrac $ TI.diffUTCTime now time) `shouldSatisfy` (< point1) - sessionMap `shouldBe` M.empty + sessionMap `shouldBe` TNTSessionData msession `shouldSatisfy` isNothing it "returns empty session and token when the session ID cookie is not present" $ do @@ -119,7 +119,7 @@ main = hspec $ parallel $ do let session = Session { sessionKey = S "123456789-123456789-1234" , sessionAuthId = Just authId - , sessionData = M.fromList [("a", "b"), ("c", "d")] + , sessionData = mkSessionMap [("a", "b"), ("c", "d")] , sessionCreatedAt = TI.addUTCTime (-10) fakenow , sessionAccessedAt = TI.addUTCTime (-5) fakenow } @@ -127,7 +127,7 @@ main = hspec $ parallel $ do st <- createState =<< prepareMockStorage [session] (retSessionMap, SaveSessionToken msession _now) <- loadSession st (Just $ B8.pack $ T.unpack $ unS $ sessionKey session) - retSessionMap `shouldBe` M.insert (authKey st) authId (sessionData session) + retSessionMap `shouldBe` onSM (M.insert (authKey st) authId) (sessionData session) msession `shouldBe` Just session describe "checkExpired" $ do @@ -214,31 +214,31 @@ main = hspec $ parallel $ do it "works for a complex example" $ do sto <- emptyMockStorage st <- createState sto - saveSession st (SaveSessionToken Nothing fakenow) M.empty `shouldReturn` Nothing + saveSession st (SaveSessionToken Nothing fakenow) emptySM `shouldReturn` Nothing getMockOperations sto `shouldReturn` [] - let m1 = M.singleton "a" "b" + let m1 = mkSessionMap [("a", "b")] Just session1 <- saveSession st (SaveSessionToken Nothing fakenow) m1 sessionAuthId session1 `shouldBe` Nothing sessionData session1 `shouldBe` m1 getMockOperations sto `shouldReturn` [InsertSession session1] - let m2 = M.insert (authKey st) "john" m1 + let m2 = onSM (M.insert (authKey st) "john") m1 Just session2 <- saveSession st (SaveSessionToken (Just session1) fakenow) m2 sessionAuthId session2 `shouldBe` Just "john" sessionData session2 `shouldBe` m1 sessionKey session2 == sessionKey session1 `shouldBe` False getMockOperations sto `shouldReturn` [DeleteSession (sessionKey session1), InsertSession session2] - let m3 = M.insert forceInvalidateKey (B8.pack $ show AllSessionIdsOfLoggedUser) m2 + let m3 = onSM (M.insert forceInvalidateKey (B8.pack $ show AllSessionIdsOfLoggedUser)) m2 Just session3 <- saveSession st (SaveSessionToken (Just session2) fakenow) m3 session3 `shouldBe` session2 { sessionKey = sessionKey session3 } getMockOperations sto `shouldReturn` [DeleteSession (sessionKey session2), DeleteAllSessionsOfAuthId "john", InsertSession session3] - let m4 = M.insert "x" "y" m2 + let m4 = onSM (M.insert "x" "y") m2 Just session4 <- saveSession st (SaveSessionToken (Just session3) fakenow) m4 - session4 `shouldBe` session3 { sessionData = M.delete (authKey st) m4 } + session4 `shouldBe` session3 { sessionData = onSM (M.delete (authKey st)) m4 } getMockOperations sto `shouldReturn` [ReplaceSession session4] Just session5 <- saveSession st (SaveSessionToken (Just session4) (TI.addUTCTime 10 fakenow)) m4 @@ -250,18 +250,18 @@ main = hspec $ parallel $ do let oldSession = Session { sessionKey = S "123456789-123456789-1234" , sessionAuthId = authId - , sessionData = M.empty + , sessionData = emptySM , sessionCreatedAt = TI.addUTCTime (-10) fakenow , sessionAccessedAt = TI.addUTCTime (-5) fakenow } sto <- prepareMockStorage [oldSession] st <- createState sto - return (oldSession, sto, st) + return (oldSession, sto :: MockStorage SessionMap, st) allEdges = let x = [Nothing, Just "john", Just "jane"] in (,) <$> x <*> x it "does not invalidate when not changing auth ID nor explicitly requesting" $ do forM_ [Nothing, Just "john"] $ \authId -> do (session, sto, st) <- prepareInvalidateIfNeeded authId - let d = DecomposedSession authId DoNotForceInvalidate M.empty + let d = DecomposedSession authId DoNotForceInvalidate emptySM invalidateIfNeeded st (Just session) d `shouldReturn` Just session getMockOperations sto `shouldReturn` [] @@ -270,21 +270,21 @@ main = hspec $ parallel $ do , (Just "admin", Nothing) , (Nothing, Just "joe") ] $ \edgeTransition -> do (session, sto, st) <- prepareInvalidateIfNeeded (fst edgeTransition) - let d = DecomposedSession (snd edgeTransition) DoNotForceInvalidate M.empty + let d = DecomposedSession (snd edgeTransition) DoNotForceInvalidate emptySM invalidateIfNeeded st (Just session) d `shouldReturn` Nothing getMockOperations sto `shouldReturn` [DeleteSession (sessionKey session)] it "invalidates the current session when CurrentSessionId is forced" $ do forM_ allEdges $ \edgeTransition -> do (session, sto, st) <- prepareInvalidateIfNeeded (fst edgeTransition) - let d = DecomposedSession (snd edgeTransition) CurrentSessionId M.empty + let d = DecomposedSession (snd edgeTransition) CurrentSessionId emptySM invalidateIfNeeded st (Just session) d `shouldReturn` Nothing getMockOperations sto `shouldReturn` [DeleteSession (sessionKey session)] it "invalidates all of the user's sessions when AllSessionIdsOfLoggedUser is forced" $ do forM_ allEdges $ \edgeTransition -> do (session, sto, st) <- prepareInvalidateIfNeeded (fst edgeTransition) - let d = DecomposedSession (snd edgeTransition) AllSessionIdsOfLoggedUser M.empty + let d = DecomposedSession (snd edgeTransition) AllSessionIdsOfLoggedUser emptySM invalidateIfNeeded st (Just session) d `shouldReturn` Nothing let expected = DeleteSession (sessionKey session) : maybe [] ((:[]) . DeleteAllSessionsOfAuthId) (snd edgeTransition) @@ -296,19 +296,19 @@ main = hspec $ parallel $ do let oldSession = Session { sessionKey = S "123456789-123456789-1234" , sessionAuthId = Just "auth" - , sessionData = M.fromList [("a", "b"), ("c", "d")] + , sessionData = mkSessionMap [("a", "b"), ("c", "d")] , sessionCreatedAt = TI.addUTCTime (-10) fakenow , sessionAccessedAt = TI.addUTCTime (-5) fakenow } sto <- prepareMockStorage [oldSession] st <- createState sto - return (oldSession, sto, st) - emptyDecomp = DecomposedSession Nothing DoNotForceInvalidate M.empty + return (oldSession, sto :: MockStorage SessionMap, st) + emptyDecomp = DecomposedSession Nothing DoNotForceInvalidate emptySM it "inserts new sessions when there wasn't an old one" $ do sto <- emptyMockStorage - st <- createState sto + st <- createState (sto :: MockStorage SessionMap) let d = DecomposedSession a DoNotForceInvalidate m - m = M.fromList [("a", "b"), ("c", "d")] + m = mkSessionMap [("a", "b"), ("c", "d")] a = Just "auth" Just session <- saveSessionOnDb st fakenow Nothing d getMockOperations sto `shouldReturn` [InsertSession session] @@ -320,7 +320,7 @@ main = hspec $ parallel $ do it "replaces sesssions when there was an old one" $ do (oldSession, sto, st) <- prepareSaveSessionOnDb let d = DecomposedSession Nothing DoNotForceInvalidate m - m = M.fromList [("a", "b"), ("x", "y")] + m = mkSessionMap [("a", "b"), ("x", "y")] Just session <- saveSessionOnDb st fakenow (Just oldSession) d getMockOperations sto `shouldReturn` [ReplaceSession session] session `shouldBe` oldSession @@ -336,7 +336,7 @@ main = hspec $ parallel $ do it "saves session if it's empty but there was an old one" $ do (oldSession, sto, st) <- prepareSaveSessionOnDb - let newSession = oldSession { sessionData = M.empty + let newSession = oldSession { sessionData = emptySM , sessionAuthId = Nothing , sessionAccessedAt = fakenow } saveSessionOnDb st fakenow (Just oldSession) emptyDecomp `shouldReturn` Just newSession @@ -356,46 +356,42 @@ main = hspec $ parallel $ do saveSessionOnDb st (t 1) (Just session1) d `shouldReturn` Just session2 getMockOperations sto `shouldReturn` [ReplaceSession session2] - describe "decomposeSession" $ do + describe "decomposeSession/SessionMap" $ do + let authKey_ = authKey stnull + prop "it is sane when not finding auth key or force invalidate key" $ \data_ -> let sessionMap = mkSessionMap $ filter (notSpecial . fst) $ data_ notSpecial = flip notElem [authKey stnull, forceInvalidateKey] . T.pack - in decomposeSession stnull sessionMap `shouldBe` + in decomposeSession authKey_ sessionMap `shouldBe` DecomposedSession Nothing DoNotForceInvalidate sessionMap prop "parses the force invalidate key" $ \data_ -> - let sessionMap v = M.insert forceInvalidateKey (B8.pack $ show v) $ mkSessionMap data_ + let sessionMap v = onSM (M.insert forceInvalidateKey (B8.pack $ show v)) $ mkSessionMap data_ allForces = [minBound..maxBound] :: [ForceInvalidate] - test v = dsForceInvalidate (decomposeSession stnull $ sessionMap v) Q.=== v + test v = dsForceInvalidate (decomposeSession authKey_ $ sessionMap v) Q.=== v in Q.conjoin (test <$> allForces) it "removes the auth key" $ do let m = M.singleton "a" "b"; m' = M.insert (authKey stnull) "x" m - decomposeSession stnull m' `shouldBe` - DecomposedSession (Just "x") DoNotForceInvalidate m + decomposeSession authKey_ (SessionMap m') `shouldBe` + DecomposedSession (Just "x") DoNotForceInvalidate (SessionMap m) - describe "toSessionMap" $ do - let mkSession authId data_ = Session - { sessionKey = error "irrelevant 1" - , sessionAuthId = authId - , sessionData = mkSessionMap data_ - , sessionCreatedAt = error "irrelevant 2" - , sessionAccessedAt = error "irrelevant 3" - } + describe "recomposeSession/SessionMap" $ do + let authKey_ = authKey stnull prop "does not change session data for sessions without auth ID" $ \data_ -> - let s = mkSession Nothing data_ - in toSessionMap stnull s Q.=== sessionData s + let s = mkSessionMap data_ + in recomposeSession authKey_ Nothing s Q.=== s prop "adds (overwriting) the auth ID to the session data" $ \authId_ data_ -> - let s = mkSession (Just authId) ((T.unpack k, "foo") : data_) - k = authKey stnull + let s = mkSessionMap ((T.unpack authKey_, "foo") : data_) authId = B8.pack authId_ - in toSessionMap stnull s Q.=== M.adjust (const authId) k (sessionData s) + in recomposeSession authKey_ (Just authId) s + Q.=== onSM (M.adjust (const authId) authKey_) s describe "MockStorage" $ do sto <- runIO emptyMockStorage @@ -404,7 +400,19 @@ main = hspec $ parallel $ do -- | Used to generate session maps on QuickCheck properties. mkSessionMap :: [(String, String)] -> SessionMap -mkSessionMap = M.fromList . map (T.pack *** B8.pack) +mkSessionMap = SessionMap . M.fromList . map (T.pack *** B8.pack) + + +-- | Apply a function to a 'SessionMap'. +onSM + :: (M.Map T.Text B8.ByteString -> M.Map T.Text B8.ByteString) + -> (SessionMap -> SessionMap) +onSM f = SessionMap . f . unSessionMap + + +-- | Empty 'SessionMap'. +emptySM :: SessionMap +emptySM = emptySession ---------------------------------------------------------------------- @@ -416,6 +424,7 @@ data TNTStorage = TNTStorage deriving (Typeable) instance Storage TNTStorage where type TransactionM TNTStorage = IO + type SessionData TNTStorage = TNTSessionData runTransactionM _ = id getSession = explode "getSession" deleteSession = explode "deleteSession" @@ -436,29 +445,52 @@ data TNTExplosion = TNTExplosion String String deriving (Show, Typeable) instance E.Exception TNTExplosion where +-- | Session data that explodes if it's used. Doesn't explode on +-- 'emptySession'. +data TNTSessionData = TNTSessionData deriving (Eq, Show, Typeable) + +instance IsSessionData TNTSessionData where + type Decomposed TNTSessionData = () + emptySession = TNTSessionData + isSameDecomposed _ = curry (explodeD "isSameDecomposed") + decomposeSession = curry (explodeD "decomposeSession") + recomposeSession = (curry . curry) (explodeD "recomposeSession") + isDecomposedEmpty _ = explodeD "isDecomposedEmpty" + + +-- | Implementation of all 'IsSessionData' methods of +-- 'TNTSessionData'. +explodeD :: Show a => String -> a -> b +explodeD fun = E.throw . TNTExplosion fun . show + + ---------------------------------------------------------------------- -- | A mock operation that was executed. -data MockOperation = - GetSession SessionId - | DeleteSession SessionId +data MockOperation sess = + GetSession (SessionId sess) + | DeleteSession (SessionId sess) | DeleteAllSessionsOfAuthId AuthId - | InsertSession Session - | ReplaceSession Session - deriving (Eq, Show, Typeable) + | InsertSession (Session sess) + | ReplaceSession (Session sess) + deriving (Typeable) + +deriving instance Eq (Decomposed sess) => Eq (MockOperation sess) +deriving instance Show (Decomposed sess) => Show (MockOperation sess) -- | A mock storage used just for testing. -data MockStorage = +data MockStorage sess = MockStorage - { mockSessions :: I.IORef (M.Map SessionId Session) - , mockOperations :: I.IORef [MockOperation] + { mockSessions :: I.IORef (M.Map (SessionId sess) (Session sess)) + , mockOperations :: I.IORef [MockOperation sess] } deriving (Typeable) -instance Storage MockStorage where - type TransactionM MockStorage = IO +instance IsSessionData sess => Storage (MockStorage sess) where + type TransactionM (MockStorage sess) = IO + type SessionData (MockStorage sess) = sess runTransactionM _ = id getSession sto sid = do -- We need to use atomicModifyIORef instead of readIORef @@ -478,7 +510,7 @@ instance Storage MockStorage where M.insertLookupWithKey (\_ v _ -> v) (sessionKey session) session oldMap in maybe (newMap, return ()) - (\oldVal -> (oldMap, E.throwIO $ SessionAlreadyExists oldVal session)) + (\oldVal -> (oldMap, mockThrow $ SessionAlreadyExists oldVal session)) moldVal addMockOperation sto (InsertSession session) replaceSession sto session = do @@ -486,14 +518,22 @@ instance Storage MockStorage where let (moldVal, newMap) = M.insertLookupWithKey (\_ v _ -> v) (sessionKey session) session oldMap in maybe - (oldMap, E.throwIO $ SessionDoesNotExist session) + (oldMap, mockThrow $ SessionDoesNotExist session) (const (newMap, return ())) moldVal addMockOperation sto (ReplaceSession session) +-- | Specialization of 'E.throwIO' for 'MockStorage'. +mockThrow + :: IsSessionData sess + => StorageException (MockStorage sess) + -> TransactionM (MockStorage sess) a +mockThrow = E.throwIO + + -- | Creates empty mock storage. -emptyMockStorage :: IO MockStorage +emptyMockStorage :: IO (MockStorage sess) emptyMockStorage = MockStorage <$> I.newIORef M.empty @@ -501,7 +541,7 @@ emptyMockStorage = -- | Creates mock storage with the given sessions already existing. -prepareMockStorage :: [Session] -> IO MockStorage +prepareMockStorage :: [Session sess] -> IO (MockStorage sess) prepareMockStorage sessions = do sto <- emptyMockStorage I.writeIORef (mockSessions sto) (M.fromList [(sessionKey s, s) | s <- sessions]) @@ -510,10 +550,10 @@ prepareMockStorage sessions = do -- | Get the list of mock operations that were made and clear -- them. The operations are listed in chronological order. -getMockOperations :: MockStorage -> IO [MockOperation] +getMockOperations :: MockStorage sess -> IO [MockOperation sess] getMockOperations = flip I.atomicModifyIORef' ((,) [] . reverse) . mockOperations -- | Add a mock operations to the log. -addMockOperation :: MockStorage -> MockOperation -> IO () +addMockOperation :: MockStorage sess -> MockOperation sess -> IO () addMockOperation sto op = I.atomicModifyIORef' (mockOperations sto) $ \ops -> (op:ops, ())