Generalize session data (huge commit).

This commit is contained in:
Felipe Lessa 2015-05-31 11:06:52 -03:00
parent 821016a382
commit 3e33c58af0
No known key found for this signature in database
GPG Key ID: A764D1843E966829
24 changed files with 1125 additions and 346 deletions

View File

@ -27,9 +27,12 @@ library
Web.ServerSession.Backend.Acid
Web.ServerSession.Backend.Acid.Internal
extensions:
ConstraintKinds
DeriveDataTypeable
FlexibleContexts
TemplateHaskell
TypeFamilies
UndecidableInstances
ghc-options: -Wall

View File

@ -26,14 +26,16 @@ module Web.ServerSession.Backend.Acid.Internal
import Control.Monad.Reader (ask)
import Control.Monad.State (get, modify', put)
import Data.Acid (AcidState, Query, Update, makeAcidic, query, update)
import Data.SafeCopy (deriveSafeCopy, base)
import Data.Acid
import Data.Acid.Advanced
import Data.SafeCopy
import Data.Typeable (Typeable)
import qualified Control.Exception as E
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import qualified Web.ServerSession.Core as SS
import qualified Web.ServerSession.Core.Internal as SSI
----------------------------------------------------------------------
@ -41,13 +43,13 @@ import qualified Web.ServerSession.Core as SS
-- | Map from session IDs to sessions. The most important map,
-- allowing us efficient access to a session given its ID.
type SessionIdToSession = M.Map SS.SessionId SS.Session
type SessionIdToSession sess = M.Map (SS.SessionId sess) (SS.Session sess)
-- | Map from auth IDs to session IDs. Allow us to invalidate
-- all sessions of given user without having to iterate through
-- the whole 'SessionIdToSession' map.
type AuthIdToSessionId = M.Map SS.AuthId (S.Set SS.SessionId)
type AuthIdToSessionId sess = M.Map SS.AuthId (S.Set (SS.SessionId sess))
-- | The current sessions.
@ -55,33 +57,37 @@ type AuthIdToSessionId = M.Map SS.AuthId (S.Set SS.SessionId)
-- Besides the obvious map from session IDs to sessions, we also
-- maintain a map of auth IDs to session IDs. This allow us to
-- quickly invalidate all sessions of a given user.
data ServerSessionAcidState =
data ServerSessionAcidState sess =
ServerSessionAcidState
{ sessionIdToSession :: !SessionIdToSession
, authIdToSessionId :: !AuthIdToSessionId
} deriving (Show, Typeable)
deriveSafeCopy 0 'base ''SS.SessionId -- dangerous!
deriveSafeCopy 0 'base ''SS.Session -- dangerous!
deriveSafeCopy 0 'base ''ServerSessionAcidState
{ sessionIdToSession :: !(SessionIdToSession sess)
, authIdToSessionId :: !(AuthIdToSessionId sess)
} deriving (Typeable)
-- | Empty 'ServerSessionAcidState' used to bootstrap the 'AcidState'.
emptyState :: ServerSessionAcidState
emptyState :: ServerSessionAcidState sess
emptyState = ServerSessionAcidState M.empty M.empty
-- | Remove the given 'SessionId' from the set of the given
-- 'AuthId' on the map. Does not do anything if no 'AuthId' is
-- provided.
removeSessionFromAuthId :: SS.SessionId -> Maybe SS.AuthId -> AuthIdToSessionId -> AuthIdToSessionId
removeSessionFromAuthId
:: SS.SessionId sess
-> Maybe SS.AuthId
-> AuthIdToSessionId sess
-> AuthIdToSessionId sess
removeSessionFromAuthId sid = maybe id (M.update (nothingfy . S.delete sid))
-- | Insert the given session ID as being part of the given auth
-- ID. Conceptually the opposite of 'removeSessionFromAuthId'.
-- Does not do anything if no 'AuthId' is provided.
insertSessionForAuthId :: SS.SessionId -> Maybe SS.AuthId -> AuthIdToSessionId -> AuthIdToSessionId
insertSessionForAuthId
:: SS.SessionId sess
-> Maybe SS.AuthId
-> AuthIdToSessionId sess
-> AuthIdToSessionId sess
insertSessionForAuthId sid = maybe id (flip (M.insertWith S.union) (S.singleton sid))
@ -93,13 +99,61 @@ nothingfy s = if S.null s then Nothing else Just s
----------------------------------------------------------------------
deriveSafeCopy 0 'base ''SS.SessionMap
-- | We can't @deriveSafeCopy 0 'base ''SS.SessionId@ as
-- otherwise we'd require an unneeded @SafeCopy sess@.
instance SafeCopy (SS.SessionId sess) where
putCopy = contain . safePut . SSI.unS
getCopy = contain $ SSI.S <$> safeGet
-- | We can't @deriveSafeCopy 0 'base ''SS.Session@ due to the
-- required context.
instance SafeCopy (SS.Decomposed sess) => SafeCopy (SS.Session sess) where
putCopy (SS.Session key authId data_ createdAt accessedAt) = contain $ do
put_t <- getSafePut
safePut key
safePut authId
safePut data_
put_t createdAt
put_t accessedAt
getCopy = contain $ do
get_t <- getSafeGet
SS.Session
<$> safeGet
<*> safeGet
<*> safeGet
<*> get_t
<*> get_t
-- | We can't @deriveSafeCopy 0 'base ''ServerSessionAcidState@ due
-- to the required context.
instance SafeCopy (SS.Decomposed sess) => SafeCopy (ServerSessionAcidState sess) where
putCopy (ServerSessionAcidState sits aits) = contain $ do
safePut sits
safePut aits
getCopy = contain $ ServerSessionAcidState <$> safeGet <*> safeGet
----------------------------------------------------------------------
-- | Get the session for the given session ID.
getSession :: SS.SessionId -> Query ServerSessionAcidState (Maybe SS.Session)
getSession
:: SS.Storage (AcidStorage sess)
=> SS.SessionId sess
-> Query (ServerSessionAcidState sess) (Maybe (SS.Session sess))
getSession sid = M.lookup sid . sessionIdToSession <$> ask
-- | Delete the session with given session ID.
deleteSession :: SS.SessionId -> Update ServerSessionAcidState ()
deleteSession
:: SS.Storage (AcidStorage sess)
=> SS.SessionId sess
-> Update (ServerSessionAcidState sess) ()
deleteSession sid = do
let removeSession = M.updateLookupWithKey (\_ _ -> Nothing) sid
modify' $ \state ->
@ -110,7 +164,10 @@ deleteSession sid = do
-- | Delete all sessions of the given auth ID.
deleteAllSessionsOfAuthId :: SS.AuthId -> Update ServerSessionAcidState ()
deleteAllSessionsOfAuthId
:: SS.Storage (AcidStorage sess)
=> SS.AuthId
-> Update (ServerSessionAcidState sess) ()
deleteAllSessionsOfAuthId authId = do
let removeSession = maybe id (flip M.difference . M.fromSet (const ()))
removeAuth = M.updateLookupWithKey (\_ _ -> Nothing) authId
@ -121,11 +178,14 @@ deleteAllSessionsOfAuthId authId = do
-- | Insert a new session.
insertSession :: SS.Session -> Update ServerSessionAcidState ()
insertSession
:: SS.Storage (AcidStorage sess)
=> SS.Session sess
-> Update (ServerSessionAcidState sess) ()
insertSession session = do
let insertSess s =
let (mold, new) = M.insertLookupWithKey (\_ v _ -> v) sid session s
in maybe new (\old -> E.throw $ SS.SessionAlreadyExists old session) mold
in maybe new (\old -> throwAS $ SS.SessionAlreadyExists old session) mold
insertAuth = insertSessionForAuthId sid (SS.sessionAuthId session)
sid = SS.sessionKey session
modify' $ \state ->
@ -135,14 +195,17 @@ insertSession session = do
-- | Replace the contents of a session.
replaceSession :: SS.Session -> Update ServerSessionAcidState ()
replaceSession
:: SS.Storage (AcidStorage sess)
=> SS.Session sess
-> Update (ServerSessionAcidState sess) ()
replaceSession session = do
-- Check that the old session exists while replacing it.
ServerSessionAcidState sits aits <- get
let (moldSession, sits') = M.insertLookupWithKey (\_ v _ -> v) sid session sits
sid = SS.sessionKey session
case moldSession of
Nothing -> E.throw $ SS.SessionDoesNotExist session
Nothing -> throwAS $ SS.SessionDoesNotExist session
Just oldSession -> do
-- Remove/insert the old auth ID from the map if needed.
let modAits | oldAuthId == newAuthId = id
@ -155,27 +218,103 @@ replaceSession session = do
put (ServerSessionAcidState sits' aits')
-- | Specialization of 'E.throw' for 'AcidStorage'.
throwAS
:: SS.Storage (AcidStorage sess)
=> SS.StorageException (AcidStorage sess)
-> a
throwAS = E.throw
----------------------------------------------------------------------
makeAcidic ''ServerSessionAcidState ['getSession, 'deleteSession, 'deleteAllSessionsOfAuthId, 'insertSession, 'replaceSession]
-- | Session storage backend using @acid-state@.
newtype AcidStorage =
newtype AcidStorage sess =
AcidStorage
{ acidState :: AcidState ServerSessionAcidState
{ acidState :: AcidState (ServerSessionAcidState sess)
-- ^ Open 'AcidState' of server sessions.
} deriving (Typeable)
-- | We do not provide any ACID guarantees for different actions
-- running inside the same @TransactionM AcidStorage@.
instance SS.Storage AcidStorage where
type TransactionM AcidStorage = IO
instance ( SS.IsSessionData sess
, SafeCopy sess
, SafeCopy (SS.Decomposed sess)
) => SS.Storage (AcidStorage sess) where
type SessionData (AcidStorage sess) = sess
type TransactionM (AcidStorage sess) = IO
runTransactionM = const id
getSession (AcidStorage s) = query s . GetSession
deleteSession (AcidStorage s) = update s . DeleteSession
deleteAllSessionsOfAuthId (AcidStorage s) = update s . DeleteAllSessionsOfAuthId
insertSession (AcidStorage s) = update s . InsertSession
replaceSession (AcidStorage s) = update s . ReplaceSession
----------------------------------------------------------------------
-- makeAcidic can't handle type variables, so we have to do
-- everything by hand. :(
data GetSession sess = GetSession (SS.SessionId sess)
data DeleteSession sess = DeleteSession (SS.SessionId sess)
data DeleteAllSessionsOfAuthId sess = DeleteAllSessionsOfAuthId SS.AuthId
data InsertSession sess = InsertSession (SS.Session sess)
data ReplaceSession sess = ReplaceSession (SS.Session sess)
instance SafeCopy (GetSession sess) where
putCopy (GetSession v) = contain $ safePut v
getCopy = contain $ GetSession <$> safeGet
instance SafeCopy (DeleteSession sess) where
putCopy (DeleteSession v) = contain $ safePut v
getCopy = contain $ DeleteSession <$> safeGet
instance SafeCopy (DeleteAllSessionsOfAuthId sess) where
putCopy (DeleteAllSessionsOfAuthId v) = contain $ safePut v
getCopy = contain $ DeleteAllSessionsOfAuthId <$> safeGet
instance SafeCopy (SS.Decomposed sess) => SafeCopy (InsertSession sess) where
putCopy (InsertSession v) = contain $ safePut v
getCopy = contain $ InsertSession <$> safeGet
instance SafeCopy (SS.Decomposed sess) => SafeCopy (ReplaceSession sess) where
putCopy (ReplaceSession v) = contain $ safePut v
getCopy = contain $ ReplaceSession <$> safeGet
type AcidContext sess =
( SS.IsSessionData sess
, SafeCopy sess
, SafeCopy (SS.Decomposed sess) )
instance AcidContext sess => QueryEvent (GetSession sess)
instance AcidContext sess => UpdateEvent (DeleteSession sess)
instance AcidContext sess => UpdateEvent (DeleteAllSessionsOfAuthId sess)
instance AcidContext sess => UpdateEvent (InsertSession sess)
instance AcidContext sess => UpdateEvent (ReplaceSession sess)
instance AcidContext sess => Method (GetSession sess) where
type MethodResult (GetSession sess) = Maybe (SS.Session sess)
type MethodState (GetSession sess) = ServerSessionAcidState sess
instance AcidContext sess => Method (DeleteSession sess) where
type MethodResult (DeleteSession sess) = ()
type MethodState (DeleteSession sess) = ServerSessionAcidState sess
instance AcidContext sess => Method (DeleteAllSessionsOfAuthId sess) where
type MethodResult (DeleteAllSessionsOfAuthId sess) = ()
type MethodState (DeleteAllSessionsOfAuthId sess) = ServerSessionAcidState sess
instance AcidContext sess => Method (InsertSession sess) where
type MethodResult (InsertSession sess) = ()
type MethodState (InsertSession sess) = ServerSessionAcidState sess
instance AcidContext sess => Method (ReplaceSession sess) where
type MethodResult (ReplaceSession sess) = ()
type MethodState (ReplaceSession sess) = ServerSessionAcidState sess
instance AcidContext sess => IsAcidic (ServerSessionAcidState sess) where
acidEvents =
[ QueryEvent $ \(GetSession sid) -> getSession sid
, UpdateEvent $ \(DeleteSession sid) -> deleteSession sid
, UpdateEvent $ \(DeleteAllSessionsOfAuthId authId) -> deleteAllSessionsOfAuthId authId
, UpdateEvent $ \(InsertSession session) -> insertSession session
, UpdateEvent $ \(ReplaceSession session) -> replaceSession session ]

View File

@ -24,7 +24,7 @@ library
, containers
, path-pieces
, persistent == 2.1.*
, persistent-template == 2.1.*
, tagged >= 0.8
, text
, time
, transformers
@ -37,14 +37,19 @@ library
extensions:
DeriveDataTypeable
EmptyDataDecls
FlexibleContexts
FlexibleInstances
GADTs
GeneralizedNewtypeDeriving
OverloadedStrings
PatternGuards
QuasiQuotes
RecordWildCards
ScopedTypeVariables
StandaloneDeriving
TemplateHaskell
TypeFamilies
UndecidableInstances
ghc-options: -Wall

View File

@ -22,9 +22,11 @@
-- share [mkPersist sqlSettings, mkSave \"entityDefs\"]
--
-- -- On Application.hs
-- import Data.Proxy (Proxy(..)) -- tagged package, or base from GHC 7.10 onwards
-- import Web.ServerSession.Core (SessionMap)
-- import Web.ServerSession.Backend.Persistent (serverSessionDefs)
--
-- mkMigrate \"migrateAll\" (serverSessionDefs ++ entityDefs)
-- mkMigrate \"migrateAll\" (serverSessionDefs (Proxy :: Proxy SessionMap) ++ entityDefs)
--
-- makeFoundation =
-- ...
@ -32,6 +34,8 @@
-- ...
-- @
--
-- If you're not using @SessionMap@, just change @Proxy@ type above.
--
-- If you forget to setup the migration above, this session
-- storage backend will fail at runtime as the required table
-- will not exist.

View File

@ -9,68 +9,284 @@ module Web.ServerSession.Backend.Persistent.Internal.Impl
, toPersistentSession
, fromPersistentSession
, SqlStorage(..)
, throwSS
) where
import Control.Monad (void)
import Control.Monad.IO.Class (liftIO)
import Data.Proxy (Proxy(..))
import Data.Time (UTCTime)
import Data.Typeable (Typeable)
import Database.Persist (PersistEntity(..))
import Database.Persist.TH (mkPersist, mkSave, persistLowerCase, share, sqlSettings)
import Web.PathPieces (PathPiece)
import Web.ServerSession.Core
import qualified Control.Exception as E
import qualified Data.Aeson as A
import qualified Data.Map as M
import qualified Data.Text as T
import qualified Database.Persist as P
import qualified Database.Persist.Sql as P
import Web.ServerSession.Backend.Persistent.Internal.Types
share
[mkPersist sqlSettings, mkSave "serverSessionDefs"]
[persistLowerCase|
PersistentSession json
key SessionId -- Session ID, primary key.
authId ByteStringJ Maybe -- Value of "_ID" session key.
session SessionMapJ -- Rest of the session data.
createdAt UTCTime -- When this session was created.
accessedAt UTCTime -- When this session was last accessed.
Primary key
deriving Eq Ord Show Typeable
|]
-- We can't use the Template Haskell since we want to generalize
-- some fields.
--
-- This is going to be a pain to upgrade when the next major
-- persistent version comes :(.
-- | Entity corresponding to a 'Session'.
--
-- We're bending @persistent@ in ways it wasn't expected to. In
-- particular, this entity is parametrized over the session type.
data PersistentSession sess =
PersistentSession
{ persistentSessionKey :: !(SessionId sess) -- ^ Session ID, primary key.
, persistentSessionAuthId :: !(Maybe ByteStringJ) -- ^ Value of "_ID" session key.
, persistentSessionSession :: !(Decomposed sess) -- ^ Rest of the session data.
, persistentSessionCreatedAt :: !UTCTime -- ^ When this session was created.
, persistentSessionAccessedAt :: !UTCTime -- ^ When this session was last accessed.
} deriving (Typeable)
deriving instance Eq (Decomposed sess) => Eq (PersistentSession sess)
deriving instance Ord (Decomposed sess) => Ord (PersistentSession sess)
deriving instance Show (Decomposed sess) => Show (PersistentSession sess)
type PersistentSessionId sess = Key (PersistentSession sess)
instance forall sess. P.PersistFieldSql (Decomposed sess) => P.PersistEntity (PersistentSession sess) where
type PersistEntityBackend (PersistentSession sess) = P.SqlBackend
data Unique (PersistentSession sess)
newtype Key (PersistentSession sess) =
PersistentSessionKey' {unPersistentSessionKey :: SessionId sess}
deriving ( Eq, Ord, Show, Read, PathPiece
, P.PersistField, P.PersistFieldSql, A.ToJSON, A.FromJSON )
data EntityField (PersistentSession sess) typ =
typ ~ PersistentSessionId sess => PersistentSessionId
| typ ~ SessionId sess => PersistentSessionKey
| typ ~ Maybe ByteStringJ => PersistentSessionAuthId
| typ ~ Decomposed sess => PersistentSessionSession
| typ ~ UTCTime => PersistentSessionCreatedAt
| typ ~ UTCTime => PersistentSessionAccessedAt
keyToValues = (:[]) . P.toPersistValue . unPersistentSessionKey
keyFromValues [x] | Right v <- P.fromPersistValue x = Right $ PersistentSessionKey' v
keyFromValues xs = Left $ T.pack $ "PersistentSession/keyFromValues: " ++ show xs
entityDef _
= P.EntityDef
(P.HaskellName "PersistentSession")
(P.DBName "persistent_session")
(pfd PersistentSessionId)
["json"]
[ pfd PersistentSessionKey
, pfd PersistentSessionAuthId
, pfd PersistentSessionSession
, pfd PersistentSessionCreatedAt
, pfd PersistentSessionAccessedAt ]
[]
[]
["Eq", "Ord", "Show", "Typeable"]
M.empty
False
where
pfd :: P.EntityField (PersistentSession sess) typ -> P.FieldDef
pfd = P.persistFieldDef
toPersistFields (PersistentSession a b c d e) =
[ P.SomePersistField a
, P.SomePersistField b
, P.SomePersistField c
, P.SomePersistField d
, P.SomePersistField e ]
fromPersistValues [a, b, c, d, e] =
PersistentSession
<$> err "key" (P.fromPersistValue a)
<*> err "authId" (P.fromPersistValue b)
<*> err "session" (P.fromPersistValue c)
<*> err "createdAt" (P.fromPersistValue d)
<*> err "accessedAt" (P.fromPersistValue e)
where
err :: T.Text -> Either T.Text a -> Either T.Text a
err s (Left r) = Left $ T.concat ["PersistentSession/fromPersistValues/", s, ": ", r]
err _ (Right v) = Right v
fromPersistValues x = Left $ T.pack $ "PersistentSession/fromPersistValues: " ++ show x
persistUniqueToFieldNames _ = error "Degenerate case, should never happen"
persistUniqueToValues _ = error "Degenerate case, should never happen"
persistUniqueKeys _ = []
persistFieldDef PersistentSessionId
= P.FieldDef
(P.HaskellName "Id")
(P.DBName "id")
(P.FTTypeCon
Nothing "PersistentSessionId")
(P.SqlOther "Composite Reference")
[]
True
(P.CompositeRef
(P.CompositeDef
[P.FieldDef
(P.HaskellName "key")
(P.DBName "key")
(P.FTTypeCon Nothing "SessionId")
(P.SqlOther "SqlType unset for key")
[]
True
P.NoReference]
[]))
persistFieldDef PersistentSessionKey
= P.FieldDef
(P.HaskellName "key")
(P.DBName "key")
(P.FTTypeCon Nothing "SessionId sess")
(P.sqlType (Proxy :: Proxy (SessionId sess)))
[]
True
P.NoReference
persistFieldDef PersistentSessionAuthId
= P.FieldDef
(P.HaskellName "authId")
(P.DBName "auth_id")
(P.FTTypeCon Nothing "ByteStringJ")
(P.sqlType (Proxy :: Proxy ByteStringJ))
["Maybe"]
True
P.NoReference
persistFieldDef PersistentSessionSession
= P.FieldDef
(P.HaskellName "session")
(P.DBName "session")
(P.FTTypeCon Nothing "Decomposed sess")
(P.sqlType (Proxy :: Proxy (Decomposed sess))) -- Important!
[]
True
P.NoReference
persistFieldDef PersistentSessionCreatedAt
= P.FieldDef
(P.HaskellName "createdAt")
(P.DBName "created_at")
(P.FTTypeCon Nothing "UTCTime")
(P.sqlType (Proxy :: Proxy UTCTime))
[]
True
P.NoReference
persistFieldDef PersistentSessionAccessedAt
= P.FieldDef
(P.HaskellName "accessedAt")
(P.DBName "accessed_at")
(P.FTTypeCon Nothing "UTCTime")
(P.sqlType (Proxy :: Proxy UTCTime))
[]
True
P.NoReference
persistIdField = PersistentSessionId
fieldLens PersistentSessionId = lensPTH
P.entityKey
(\(P.Entity _ v) k -> P.Entity k v)
fieldLens PersistentSessionKey = lensPTH
(persistentSessionKey . P.entityVal)
(\(P.Entity k v) x -> P.Entity k (v {persistentSessionKey = x}))
fieldLens PersistentSessionAuthId = lensPTH
(persistentSessionAuthId . P.entityVal)
(\(P.Entity k v) x -> P.Entity k (v {persistentSessionAuthId = x}))
fieldLens PersistentSessionSession = lensPTH
(persistentSessionSession . P.entityVal)
(\(P.Entity k v) x -> P.Entity k (v {persistentSessionSession = x}))
fieldLens PersistentSessionCreatedAt = lensPTH
(persistentSessionCreatedAt . P.entityVal)
(\(P.Entity k v) x -> P.Entity k (v {persistentSessionCreatedAt = x}))
fieldLens PersistentSessionAccessedAt = lensPTH
(persistentSessionAccessedAt . P.entityVal)
(\(P.Entity k v) x -> P.Entity k (v {persistentSessionAccessedAt = x}))
-- | Copy-paste from @Database.Persist.TH@. Who needs lens anyway...
lensPTH :: Functor f => (s -> a) -> (s -> b -> t) -> (a -> f b) -> s -> f t
lensPTH sa sbt afb s = fmap (sbt s) (afb $ sa s)
instance A.ToJSON (Decomposed sess) => A.ToJSON (PersistentSession sess) where
toJSON (PersistentSession key authId session createdAt accessedAt) =
A.object
[ "key" A..= key
, "authId" A..= authId
, "session" A..= session
, "createdAt" A..= createdAt
, "accessedAt" A..= accessedAt ]
instance A.FromJSON (Decomposed sess) => A.FromJSON (PersistentSession sess) where
parseJSON (A.Object obj) =
PersistentSession
<$> obj A..: "key"
<*> obj A..: "authId"
<*> obj A..: "session"
<*> obj A..: "createdAt"
<*> obj A..: "accessedAt"
parseJSON _ = mempty
instance ( A.ToJSON (Decomposed sess)
, P.PersistFieldSql (Decomposed sess)
) => A.ToJSON (P.Entity (PersistentSession sess)) where
toJSON = P.entityIdToJSON
instance ( A.FromJSON (Decomposed sess)
, P.PersistFieldSql (Decomposed sess)
) => A.FromJSON (P.Entity (PersistentSession sess)) where
parseJSON = P.entityIdFromJSON
-- | Entity definitions needed to generate the SQL schema for
-- 'SqlStorage'. Example using 'SessionMap':
--
-- @
-- serverSessionDefs (Proxy :: Proxy SessionMap)
-- @
serverSessionDefs :: forall sess. PersistEntity (PersistentSession sess) => Proxy sess -> [P.EntityDef]
serverSessionDefs _ = [entityDef (Proxy :: Proxy (PersistentSession sess))]
-- | Generate a key to the entity from the session ID.
psKey :: SessionId -> Key PersistentSession
psKey :: SessionId sess -> Key (PersistentSession sess)
psKey = PersistentSessionKey'
-- | Convert from 'Session' to 'PersistentSession'.
toPersistentSession :: Session -> PersistentSession
toPersistentSession :: Session sess -> PersistentSession sess
toPersistentSession Session {..} =
PersistentSession
{ persistentSessionKey = sessionKey
, persistentSessionAuthId = fmap B sessionAuthId
, persistentSessionSession = M sessionData
, persistentSessionSession = sessionData
, persistentSessionCreatedAt = sessionCreatedAt
, persistentSessionAccessedAt = sessionAccessedAt
}
-- | Convert from 'PersistentSession' to 'Session'.
fromPersistentSession :: PersistentSession -> Session
fromPersistentSession :: PersistentSession sess -> Session sess
fromPersistentSession PersistentSession {..} =
Session
{ sessionKey = persistentSessionKey
, sessionAuthId = fmap unB persistentSessionAuthId
, sessionData = unM persistentSessionSession
, sessionData = persistentSessionSession
, sessionCreatedAt = persistentSessionCreatedAt
, sessionAccessedAt = persistentSessionAccessedAt
}
-- | SQL session storage backend using @persistent@.
newtype SqlStorage =
newtype SqlStorage sess =
SqlStorage
{ connPool :: P.ConnectionPool
-- ^ Pool of DB connections. You may use the same pool as
@ -78,22 +294,38 @@ newtype SqlStorage =
} deriving (Typeable)
instance Storage SqlStorage where
type TransactionM SqlStorage = P.SqlPersistT IO
instance forall sess.
( IsSessionData sess
, P.PersistFieldSql (Decomposed sess)
) => Storage (SqlStorage sess) where
type SessionData (SqlStorage sess) = sess
type TransactionM (SqlStorage sess) = P.SqlPersistT IO
runTransactionM = flip P.runSqlPool . connPool
getSession _ = fmap (fmap fromPersistentSession) . P.get . psKey
deleteSession _ = P.delete . psKey
deleteAllSessionsOfAuthId _ authId = P.deleteWhere [PersistentSessionAuthId P.==. Just (B authId)]
deleteAllSessionsOfAuthId _ authId =
P.deleteWhere [field P.==. Just (B authId)]
where
field :: EntityField (PersistentSession sess) (Maybe ByteStringJ)
field = PersistentSessionAuthId
insertSession s session = do
mold <- getSession s (sessionKey session)
maybe
(void $ P.insert $ toPersistentSession session)
(\old -> liftIO $ E.throwIO $ SessionAlreadyExists old session)
(\old -> throwSS $ SessionAlreadyExists old session)
mold
replaceSession _ session = do
let key = psKey $ sessionKey session
mold <- P.get key
maybe
(liftIO $ E.throwIO $ SessionDoesNotExist session)
(throwSS $ SessionDoesNotExist session)
(\_old -> void $ P.replace key $ toPersistentSession session)
mold
-- | Specialization of 'E.throwIO' for 'SqlStorage'.
throwSS
:: Storage (SqlStorage sess)
=> StorageException (SqlStorage sess)
-> TransactionM (SqlStorage sess) a
throwSS = liftIO . E.throwIO

View File

@ -5,7 +5,11 @@
-- Also exports orphan instances of @PersistField{,Sql} SessionId@.
module Web.ServerSession.Backend.Persistent.Internal.Types
( ByteStringJ(..)
, SessionMapJ(..)
-- * Orphan instances
-- ** SessionId
-- $orphanSessionId
-- ** SessionMap
-- $orphanSessionMap
) where
import Control.Arrow (first)
@ -29,12 +33,19 @@ import qualified Data.Text.Encoding as TE
----------------------------------------------------------------------
-- | Does not do sanity checks (DB is trusted).
instance PersistField SessionId where
-- $orphanSessionId
--
-- @
-- instance 'PersistField' ('SessionId' sess)
-- instance 'PersistFieldSql' ('SessionId' sess)
-- @
--
-- Does not do sanity checks (DB is trusted).
instance PersistField (SessionId sess) where
toPersistValue = toPersistValue . unS
fromPersistValue = fmap S . fromPersistValue
instance PersistFieldSql SessionId where
instance PersistFieldSql (SessionId sess) where
sqlType p = sqlType (fmap unS p)
@ -66,31 +77,41 @@ instance A.ToJSON ByteStringJ where
----------------------------------------------------------------------
-- | Newtype of a 'SessionMap' that serializes using @cereal@ on
-- $orphanSessionMap
--
-- @
-- instance 'PersistField' 'SessionMap'
-- instance 'PersistFieldSql' 'SessionMap'
-- instance 'S.Serialize' 'SessionMap'
-- instance 'A.FromJSON' 'SessionMap'
-- instance 'A.ToJSON' 'SessionMap'
-- @
-- 'PersistField' for 'SessionMap' serializes using @cereal@ on
-- the database. We tried to use @aeson@ but @cereal@ is twice
-- faster and uses half the memory for this use case.
newtype SessionMapJ = M { unM :: SessionMap }
deriving (Eq, Ord, Show, Read, Typeable)
instance PersistField SessionMapJ where
--
-- The JSON instance translates to objects using base64url for
-- the values of 'ByteString' (cf. 'ByteStringJ').
instance PersistField SessionMap where
toPersistValue = toPersistValue . S.encode
fromPersistValue = fromPersistValue >=> (either (Left . T.pack) Right . S.decode)
instance PersistFieldSql SessionMapJ where
instance PersistFieldSql SessionMap where
sqlType _ = SqlBlob
instance S.Serialize SessionMapJ where
put = S.put . map (first TE.encodeUtf8) . M.toAscList . unM
get = M . M.fromAscList . map (first TE.decodeUtf8) <$> S.get
instance S.Serialize SessionMap where
put = S.put . map (first TE.encodeUtf8) . M.toAscList . unSessionMap
get = SessionMap . M.fromAscList . map (first TE.decodeUtf8) <$> S.get
instance A.FromJSON SessionMapJ where
instance A.FromJSON SessionMap where
parseJSON = fmap fixup . A.parseJSON
where
fixup :: M.Map Text ByteStringJ -> SessionMapJ
fixup = M . fmap unB
fixup :: M.Map Text ByteStringJ -> SessionMap
fixup = SessionMap . fmap unB
instance A.ToJSON SessionMapJ where
instance A.ToJSON SessionMap where
toJSON = A.toJSON . mangle
where
mangle :: SessionMapJ -> M.Map Text ByteStringJ
mangle = fmap B . unM
mangle :: SessionMap -> M.Map Text ByteStringJ
mangle = fmap B . unSessionMap

View File

@ -3,17 +3,19 @@ module Main (main) where
import Control.Monad (forM_)
import Control.Monad.Logger (runStderrLoggingT, runNoLoggingT)
import Data.Pool (destroyAllResources)
import Data.Proxy (Proxy(..))
import Database.Persist.Postgresql (createPostgresqlPool)
import Database.Persist.Sqlite (createSqlitePool)
import Test.Hspec
import Web.ServerSession.Backend.Persistent
import Web.ServerSession.Core (SessionMap)
import Web.ServerSession.Core.StorageTests
import qualified Control.Exception as E
import qualified Database.Persist.TH as P
import qualified Database.Persist.Sql as P
P.mkMigrate "migrateAll" serverSessionDefs
P.mkMigrate "migrateAll" (serverSessionDefs (Proxy :: Proxy SessionMap))
main :: IO ()
main = hspec $

View File

@ -21,6 +21,7 @@ library
, containers
, hedis == 0.6.*
, path-pieces
, tagged >= 0.8
, text
, time >= 1.5
, transformers
@ -31,8 +32,10 @@ library
Web.ServerSession.Backend.Redis.Internal
extensions:
DeriveDataTypeable
FlexibleContexts
OverloadedStrings
RecordWildCards
ScopedTypeVariables
TypeFamilies
ghc-options: -Wall

View File

@ -2,6 +2,7 @@
module Web.ServerSession.Backend.Redis
( RedisStorage(..)
, RedisStorageException(..)
, RedisSession(..)
) where
import Web.ServerSession.Backend.Redis.Internal

View File

@ -9,6 +9,7 @@ module Web.ServerSession.Backend.Redis.Internal
, rSessionKey
, rAuthKey
, RedisSession(..)
, parseSession
, printSession
, parseUTCTime
@ -22,6 +23,7 @@ module Web.ServerSession.Backend.Redis.Internal
, deleteAllSessionsOfAuthIdImpl
, insertSessionImpl
, replaceSessionImpl
, throwRS
) where
import Control.Applicative ((<$))
@ -31,6 +33,7 @@ import Control.Monad.IO.Class (liftIO)
import Data.ByteString (ByteString)
import Data.List (partition)
import Data.Maybe (fromMaybe)
import Data.Proxy (Proxy(..))
import Data.Typeable (Typeable)
import Web.PathPieces (toPathPiece)
import Web.ServerSession.Core
@ -49,7 +52,7 @@ import qualified Data.Time.Format as TI
-- | Session storage backend using Redis via the @hedis@ package.
newtype RedisStorage =
newtype RedisStorage sess =
RedisStorage
{ connPool :: R.Connection
-- ^ Connection pool to the Redis server.
@ -58,8 +61,9 @@ newtype RedisStorage =
-- | We do not provide any ACID guarantees for different actions
-- running inside the same @TransactionM RedisStorage@.
instance Storage RedisStorage where
type TransactionM RedisStorage = R.Redis
instance RedisSession sess => Storage (RedisStorage sess) where
type SessionData (RedisStorage sess) = sess
type TransactionM (RedisStorage sess) = R.Redis
runTransactionM = R.runRedis . connPool
getSession _ = getSessionImpl
deleteSession _ = deleteSessionImpl
@ -102,7 +106,7 @@ unwrap act = act >>= either (liftIO . E.throwIO . ExpectedRight) return
-- | Redis key for the given session ID.
rSessionKey :: SessionId -> ByteString
rSessionKey :: SessionId sess -> ByteString
rSessionKey = B.append "ssr:session:" . TE.encodeUtf8 . toPathPiece
@ -114,8 +118,39 @@ rAuthKey = B.append "ssr:authid:"
----------------------------------------------------------------------
-- | Class for data types that can be used as session data for
-- the Redis backend.
--
-- It should hold that
--
-- @
-- fromHash p . perm . toHash p === id
-- @
--
-- for all list permutations @perm :: [a] -> [a]@,
-- where @p :: Proxy sess@.
class IsSessionData sess => RedisSession sess where
-- | Transform a decomposed session into a Redis hash. Keys
-- will be prepended with @\"data:\"@ before being stored.
toHash :: Proxy sess -> Decomposed sess -> [(ByteString, ByteString)]
-- | Parse back a Redis hash into session data.
fromHash :: Proxy sess -> [(ByteString, ByteString)] -> Decomposed sess
-- | Assumes that keys are UTF-8 encoded when parsing (which is
-- true if keys are always generated via @toHash@).
instance RedisSession SessionMap where
toHash _ = map (first TE.encodeUtf8) . M.toList . unSessionMap
fromHash _ = SessionMap . M.fromList . map (first TE.decodeUtf8)
-- | Parse a 'Session' from a Redis hash.
parseSession :: SessionId -> [(ByteString, ByteString)] -> Maybe Session
parseSession
:: forall sess. RedisSession sess
=> SessionId sess
-> [(ByteString, ByteString)]
-> Maybe (Session sess)
parseSession _ [] = Nothing
parseSession sid bss =
let (externalList, internalList) = partition (B8.isPrefixOf "data:" . fst) bss
@ -124,25 +159,26 @@ parseSession sid bss =
accessedAt = parseUTCTime $ lookup' "internal:accessedAt"
lookup' k = fromMaybe (error err) $ lookup k internalList
where err = "serversession-backend-redis/parseSession: missing key " ++ show k
sessionMap = M.fromList $ map (first $ TE.decodeUtf8 . removePrefix) externalList
data_ = fromHash p $ map (first removePrefix) externalList
where removePrefix bs = let ("data:", key) = B8.splitAt 5 bs in key
p = Proxy :: Proxy sess
in Just Session
{ sessionKey = sid
, sessionAuthId = authId
, sessionData = sessionMap
, sessionData = data_
, sessionCreatedAt = createdAt
, sessionAccessedAt = accessedAt
}
-- | Convert a 'Session' into a Redis hash.
printSession :: Session -> [(ByteString, ByteString)]
printSession :: forall sess. RedisSession sess => Session sess -> [(ByteString, ByteString)]
printSession Session {..} =
maybe id ((:) . (,) "internal:authId") sessionAuthId $
(:) ("internal:createdAt", printUTCTime sessionCreatedAt) $
(:) ("internal:accessedAt", printUTCTime sessionAccessedAt) $
map (first $ B8.append "data:" . TE.encodeUtf8) $
M.toList sessionData
map (first $ B8.append "data:") $
toHash (Proxy :: Proxy sess) sessionData
-- | Parse 'UTCTime' from a 'ByteString' stored on Redis. Uses
@ -177,12 +213,12 @@ batched f xs =
-- | Get the session for the given session ID.
getSessionImpl :: SessionId -> R.Redis (Maybe Session)
getSessionImpl :: RedisSession sess => SessionId sess -> R.Redis (Maybe (Session sess))
getSessionImpl sid = parseSession sid <$> unwrap (R.hgetall $ rSessionKey sid)
-- | Delete the session with given session ID.
deleteSessionImpl :: SessionId -> R.Redis ()
deleteSessionImpl :: RedisSession sess => SessionId sess -> R.Redis ()
deleteSessionImpl sid = do
msession <- getSessionImpl sid
case msession of
@ -196,18 +232,22 @@ deleteSessionImpl sid = do
-- | Remove the given 'SessionId' from the set of sessions of the
-- given 'AuthId'. Does not do anything if @Nothing@.
removeSessionFromAuthId :: R.RedisCtx m f => SessionId -> Maybe AuthId -> m ()
removeSessionFromAuthId :: R.RedisCtx m f => SessionId sess -> Maybe AuthId -> m ()
removeSessionFromAuthId = fooSessionBarAuthId R.srem
-- | Insert the given 'SessionId' into the set of sessions of the
-- given 'AuthId'. Does not do anything if @Nothing@.
insertSessionForAuthId :: R.RedisCtx m f => SessionId -> Maybe AuthId -> m ()
insertSessionForAuthId :: R.RedisCtx m f => SessionId sess -> Maybe AuthId -> m ()
insertSessionForAuthId = fooSessionBarAuthId R.sadd
-- | (Internal) Helper for 'removeSessionFromAuthId' and 'insertSessionForAuthId'
fooSessionBarAuthId
:: R.RedisCtx m f => (ByteString -> [ByteString] -> m (f Integer)) -> SessionId -> Maybe AuthId -> m ()
:: R.RedisCtx m f
=> (ByteString -> [ByteString] -> m (f Integer))
-> SessionId sess
-> Maybe AuthId
-> m ()
fooSessionBarAuthId _ _ Nothing = return ()
fooSessionBarAuthId fun sid (Just authId) = void $ fun (rAuthKey authId) [rSessionKey sid]
@ -220,13 +260,13 @@ deleteAllSessionsOfAuthIdImpl authId = do
-- | Insert a new session.
insertSessionImpl :: Session -> R.Redis ()
insertSessionImpl :: RedisSession sess => Session sess -> R.Redis ()
insertSessionImpl session = do
-- Check that no old session exists.
let sid = sessionKey session
moldSession <- getSessionImpl sid
case moldSession of
Just oldSession -> liftIO $ E.throwIO $ SessionAlreadyExists oldSession session
Just oldSession -> throwRS $ SessionAlreadyExists oldSession session
Nothing -> do
transaction $ do
let sk = rSessionKey sid
@ -237,13 +277,13 @@ insertSessionImpl session = do
-- | Replace the contents of a session.
replaceSessionImpl :: Session -> R.Redis ()
replaceSessionImpl :: RedisSession sess => Session sess -> R.Redis ()
replaceSessionImpl session = do
-- Check that the old session exists.
let sid = sessionKey session
moldSession <- getSessionImpl sid
case moldSession of
Nothing -> liftIO $ E.throwIO $ SessionDoesNotExist session
Nothing -> throwRS $ SessionDoesNotExist session
Just oldSession -> do
transaction $ do
-- Delete the old session and set the new one.
@ -259,3 +299,11 @@ replaceSessionImpl session = do
insertSessionForAuthId sid newAuthId
return (() <$ r)
-- | Specialization of 'E.throwIO' for 'RedisStorage'.
throwRS
:: Storage (RedisStorage sess)
=> StorageException (RedisStorage sess)
-> R.Redis a
throwRS = liftIO . E.throwIO

View File

@ -32,7 +32,10 @@ library
Web.ServerSession.Frontend.Snap
Web.ServerSession.Frontend.Snap.Internal
extensions:
FlexibleContexts
OverloadedStrings
TypeFamilies
UndecidableInstances
ghc-options: -Wall
source-repository head

View File

@ -3,6 +3,7 @@ module Web.ServerSession.Frontend.Snap
( -- * Using server-side sessions
initServerSessionManager
, simpleServerSessionManager
, SnapSession(..)
-- * Invalidating session IDs
, forceInvalidate
, ForceInvalidate(..)

View File

@ -3,6 +3,7 @@
module Web.ServerSession.Frontend.Snap.Internal
( initServerSessionManager
, simpleServerSessionManager
, SnapSession(..)
, ServerSessionManager(..)
, currentSessionMap
, modifyCurrentSession
@ -33,7 +34,10 @@ import qualified Snap.Snaplet.Session.SessionManager as S
-- | Create a new 'ServerSessionManager' using the given 'State'.
initServerSessionManager :: Storage s => IO (State s) -> S.SnapletInit b S.SessionManager
initServerSessionManager
:: (Storage sto, SnapSession (SessionData sto))
=> IO (State sto)
-> S.SnapletInit b S.SessionManager
initServerSessionManager mkState =
S.makeSnaplet "ServerSession"
"Snaplet providing sessions via server-side storage."
@ -51,17 +55,67 @@ initServerSessionManager mkState =
-- | Simplified version of 'initServerSessionManager', sufficient
-- for most needs.
simpleServerSessionManager :: Storage s => IO s -> (State s -> State s) -> S.SnapletInit b S.SessionManager
simpleServerSessionManager
:: (Storage sto, SessionData sto ~ SessionMap)
=> IO sto
-> (State sto -> State sto)
-> S.SnapletInit b S.SessionManager
simpleServerSessionManager mkStorage opts =
initServerSessionManager (fmap opts . createState =<< mkStorage)
----------------------------------------------------------------------
-- | Class for data types that implement the operations Snap
-- expects sessions to support.
class IsSessionData sess => SnapSession sess where
ssInsert :: Text -> Text -> sess -> sess
ssLookup :: Text -> sess -> Maybe Text
ssDelete :: Text -> sess -> sess
ssToList :: sess -> [(Text, Text)]
ssInsertCsrf :: Text -> sess -> sess
ssLookupCsrf :: sess -> Maybe Text
ssForceInvalidate :: ForceInvalidate -> sess -> sess
-- | Uses 'csrfKey'.
instance SnapSession SessionMap where
ssInsert key val = onSM (M.insert key (TE.encodeUtf8 val))
ssLookup key = fmap TE.decodeUtf8 . M.lookup key . unSessionMap
ssDelete key = onSM (M.delete key)
ssToList =
-- Remove the CSRF key from the list as the current
-- clientsession backend doesn't return it.
fmap (second TE.decodeUtf8) .
M.toList .
M.delete csrfKey .
unSessionMap
ssInsertCsrf = ssInsert csrfKey
ssLookupCsrf = ssLookup csrfKey
ssForceInvalidate force = onSM (M.insert forceInvalidateKey (B8.pack $ show force))
-- | Apply a function to a 'SessionMap'.
onSM
:: (M.Map Text ByteString -> M.Map Text ByteString)
-> (SessionMap -> SessionMap)
onSM f = SessionMap . f . unSessionMap
----------------------------------------------------------------------
-- | A 'S.ISessionManager' using server-side sessions.
data ServerSessionManager s =
data ServerSessionManager sto =
ServerSessionManager
{ currentSession :: Maybe (SessionMap, SaveSessionToken)
{ currentSession :: Maybe (SessionData sto, SaveSessionToken sto)
-- ^ Field used for per-request caching of the session.
, state :: State s
, state :: State sto
-- ^ The core @serversession@ state.
, cookieName :: ByteString
-- ^ Cache of the cookie name as bytestring.
@ -70,26 +124,32 @@ data ServerSessionManager s =
} deriving (Typeable)
instance Storage s => S.ISessionManager (ServerSessionManager s) where
instance ( Storage sto
, SnapSession (SessionData sto)
) => S.ISessionManager (ServerSessionManager sto) where
load ssm@ServerSessionManager { currentSession = Just _ } =
-- Don't do anything if already loaded. Yeah, I know this is
-- strange, go figure.
return ssm
load ssm = do
-- Get session ID from cookie.
mcookie <- S.getCookie (cookieName ssm)
-- Load session from storage backend.
(sessionMap, saveSessionToken) <-
(data1, saveSessionToken) <-
liftIO $ loadSession (state ssm) (S.cookieValue <$> mcookie)
-- Add CSRF token if needed.
sessionMap' <-
data2 <-
maybe
(flip (M.insert csrfKey) sessionMap <$> N.nonce128url (nonceGen ssm))
(const $ return sessionMap)
(M.lookup csrfKey sessionMap)
(flip ssInsertCsrf data1 <$> N.nonce128urlT (nonceGen ssm))
(const $ return data1)
(ssLookupCsrf data1)
-- Good to go!
return ssm { currentSession = Just (sessionMap', saveSessionToken) }
return ssm { currentSession = Just (data2, saveSessionToken) }
commit ssm = do
-- Save session data to storage backend and set the cookie.
let Just (sessionMap, saveSessionToken) = currentSession ssm
msession <- liftIO $ saveSession (state ssm) saveSessionToken sessionMap
let Just (data_, saveSessionToken) = currentSession ssm
msession <- liftIO $ saveSession (state ssm) saveSessionToken data_
S.modifyResponse $ S.addResponseCookie $
maybe
(deleteCookie (state ssm) (cookieName ssm))
@ -100,60 +160,62 @@ instance Storage s => S.ISessionManager (ServerSessionManager s) where
-- Reset has no defined semantics. We invalidate the session
-- and clear its variables, which seems to be what the
-- current clientsession backend from the snap package does.
csrfToken <- N.nonce128url (nonceGen ssm)
let newSession = M.fromList [ (forceInvalidateKey, B8.pack $ show CurrentSessionId)
, (csrfKey, csrfToken) ]
csrfToken <- N.nonce128urlT (nonceGen ssm)
let newSession =
ssInsertCsrf csrfToken $
ssForceInvalidate CurrentSessionId $
emptySession
return $ modifyCurrentSession (const newSession) ssm
touch =
-- We always touch the session (if commit is called).
id
insert key value = modifyCurrentSession (M.insert key (TE.encodeUtf8 value))
insert key value = modifyCurrentSession (ssInsert key value)
lookup key =
-- Decoding will always succeed if the session is used only
-- from snap.
fmap TE.decodeUtf8 . M.lookup key . currentSessionMap "lookup"
ssLookup key . currentSessionMap "lookup"
delete key = modifyCurrentSession (M.delete key)
delete key = modifyCurrentSession (ssDelete key)
csrf =
-- Guaranteed to succeed since both load and reset add a
-- csrfKey to the session map.
fromMaybe (error "serversession-frontend-snap/csrf: never here") .
S.lookup csrfKey
ssLookupCsrf . currentSessionMap "csrf"
toList =
-- Remove the CSRF key from the list as the current
-- clientsession backend doesn't return it.
fmap (second TE.decodeUtf8) .
M.toList .
M.delete csrfKey .
currentSessionMap "toList"
toList = ssToList . currentSessionMap "toList"
-- | Get the current 'SessionMap' from 'currentSession' and
-- | Get the current 'SessionData' from 'currentSession' and
-- unwrap its @Just@. If it's @Nothing@, @error@ is called. We
-- expect 'load' to be called before any other 'ISessionManager'
-- method.
currentSessionMap :: String -> ServerSessionManager s -> SessionMap
currentSessionMap :: String -> ServerSessionManager sto -> SessionData sto
currentSessionMap fn ssm = maybe (error err) fst (currentSession ssm)
where err = "serversession-frontend-snap/" ++ fn ++
": currentSession is Nothing, did you call 'load'?"
-- | Modify the current session in any way.
modifyCurrentSession :: (SessionMap -> SessionMap) -> ServerSessionManager s -> ServerSessionManager s
modifyCurrentSession
:: (SessionData sto -> SessionData sto)
-> ServerSessionManager sto
-> ServerSessionManager sto
modifyCurrentSession f ssm = ssm { currentSession = fmap (first f) (currentSession ssm) }
----------------------------------------------------------------------
-- | Create a cookie for the given session.
--
-- The cookie expiration is set via 'nextExpires'. Note that
-- this is just an optimization, as the expiration is checked on
-- the server-side as well.
createCookie :: State s -> ByteString -> Session -> S.Cookie
createCookie :: State sto -> ByteString -> Session sess -> S.Cookie
createCookie st cookieNameBS session =
-- Generate a cookie with the final session ID.
S.Cookie
@ -176,7 +238,7 @@ createCookie st cookieNameBS session =
-- * If the user had a session cookie that was invalidated,
-- this will remove the invalid cookie from the client.
-- the server-side as well.
deleteCookie :: State s -> ByteString -> S.Cookie
deleteCookie :: State sto -> ByteString -> S.Cookie
deleteCookie st cookieNameBS =
S.Cookie
{ S.cookieName = cookieNameBS

View File

@ -34,7 +34,9 @@ library
Web.ServerSession.Frontend.Wai
Web.ServerSession.Frontend.Wai.Internal
extensions:
FlexibleContexts
OverloadedStrings
TypeFamilies
ghc-options: -Wall
source-repository head

View File

@ -17,6 +17,7 @@ module Web.ServerSession.Frontend.Wai
-- * Flexible interface
, sessionStore
, createCookieTemplate
, KeyValue(..)
-- * State configuration
, setCookieName
, setAuthKey

View File

@ -4,6 +4,7 @@ module Web.ServerSession.Frontend.Wai.Internal
( withServerSession
, sessionStore
, mkSession
, KeyValue(..)
, createCookieTemplate
, calculateMaxAge
, forceInvalidate
@ -34,10 +35,10 @@ import qualified Web.Cookie as C
-- that uses 'WS.withSession', 'createState', 'sessionStore',
-- 'getCookieName' and 'createCookieTemplate'.
withServerSession
:: (MonadIO m, MonadIO n, Storage s)
:: (MonadIO m, MonadIO n, Storage sto, SessionData sto ~ SessionMap)
=> V.Key (WS.Session m Text ByteString) -- ^ 'V.Vault' key to use when passing the session through.
-> (State s -> State s) -- ^ Set any options on the @serversession@ state.
-> s -- ^ Storage backend.
-> (State sto -> State sto) -- ^ Set any options on the @serversession@ state.
-> sto -- ^ Storage backend.
-> n W.Middleware
withServerSession key opts storage = liftIO $ do
st <- opts <$> createState storage
@ -56,32 +57,55 @@ withServerSession key opts storage = liftIO $ do
-- return an empty @ByteString@ when the empty session was not
-- saved.
sessionStore
:: (MonadIO m, Storage s)
=> State s -- ^ @serversession@ state, incl. storage backend.
-> WS.SessionStore m Text ByteString -- ^ @wai-session@ session store.
:: (MonadIO m, Storage sto, KeyValue (SessionData sto))
=> State sto -- ^ @serversession@ state, incl. storage backend.
-> WS.SessionStore m (Key (SessionData sto)) (Value (SessionData sto))
-- ^ @wai-session@ session store.
sessionStore state =
\mcookieVal -> do
(sessionMap, saveSessionToken) <- loadSession state mcookieVal
sessionRef <- I.newIORef sessionMap
(data1, saveSessionToken) <- loadSession state mcookieVal
sessionRef <- I.newIORef data1
let save = do
sessionMap' <- I.atomicModifyIORef' sessionRef $ \a -> (a, a)
msession <- saveSession state saveSessionToken sessionMap'
data2 <- I.atomicModifyIORef' sessionRef $ \a -> (a, a)
msession <- saveSession state saveSessionToken data2
return $ maybe "" (TE.encodeUtf8 . toPathPiece . sessionKey) msession
return (mkSession sessionRef, save)
-- | Build a 'WS.Session' from an 'I.IORef' containing the
-- session data.
mkSession :: MonadIO m => I.IORef SessionMap -> WS.Session m Text ByteString
mkSession :: (MonadIO m, KeyValue sess) => I.IORef sess -> WS.Session m (Key sess) (Value sess)
mkSession sessionRef =
-- We need to use atomicModifyIORef instead of readIORef
-- because latter may be reordered (cf. "Memory Model" on
-- Data.IORef's documentation).
( \k -> M.lookup k <$> liftIO (I.atomicModifyIORef' sessionRef $ \a -> (a, a))
, \k v -> liftIO (I.atomicModifyIORef' sessionRef $ flip (,) () . M.insert k v)
( \k -> kvLookup k <$> liftIO (I.atomicModifyIORef' sessionRef $ \a -> (a, a))
, \k v -> liftIO (I.atomicModifyIORef' sessionRef $ flip (,) () . kvInsert k v)
)
----------------------------------------------------------------------
-- | Class for session data types that can be used as key-value
-- stores.
class IsSessionData sess => KeyValue sess where
type Key sess :: *
type Value sess :: *
kvLookup :: Key sess -> sess -> Maybe (Value sess)
kvInsert :: Key sess -> Value sess -> sess -> sess
instance KeyValue SessionMap where
type Key SessionMap = Text
type Value SessionMap = ByteString
kvLookup k = M.lookup k . unSessionMap
kvInsert k v (SessionMap m) = SessionMap (M.insert k v m)
----------------------------------------------------------------------
-- | Create a cookie template given a state.
--
-- Since we don't have access to the 'Session', we can't fill the
@ -94,7 +118,7 @@ mkSession sessionRef =
-- Instead, we fill only the @Max-age@ field. It works fine for
-- modern browsers, but many don't support it and will treat the
-- cookie as non-persistent (notably IE 6, 7 and 8).
createCookieTemplate :: State s -> C.SetCookie
createCookieTemplate :: State sto -> C.SetCookie
createCookieTemplate state =
-- Generate a cookie with the final session ID.
def
@ -115,7 +139,7 @@ createCookieTemplate state =
-- * If no timeout is defined, the result is 10 years.
--
-- * Otherwise, the max age is set as the maximum timeout.
calculateMaxAge :: State s -> Maybe TI.DiffTime
calculateMaxAge :: State sto -> Maybe TI.DiffTime
calculateMaxAge st = do
guard (persistentCookies st)
return $ maybe (60*60*24*3652) realToFrac

View File

@ -18,6 +18,7 @@ library
build-depends:
base == 4.*
, bytestring
, containers
, cookie >= 0.4
, data-default
, path-pieces
@ -32,7 +33,9 @@ library
Web.ServerSession.Frontend.Yesod
Web.ServerSession.Frontend.Yesod.Internal
extensions:
FlexibleContexts
OverloadedStrings
TypeFamilies
ghc-options: -Wall
source-repository head

View File

@ -1,8 +1,19 @@
-- | Yesod server-side session support.
--
-- This package implements an Yesod @SessionBackend@, so it's a
-- drop-in replacement for the default @clientsession@.
--
-- Unfortunately, Yesod currently provides no way of accessing
-- the session other than via its own functions. If you want to
-- use a custom data type as your session data (instead of the
-- default @SessionMap@), it will have to implement
-- 'IsSessionMap' and you'll have to continue using Yesod's
-- session interface.
module Web.ServerSession.Frontend.Yesod
( -- * Using server-side sessions
simpleBackend
, backend
, IsSessionMap(..)
-- * Invalidating session IDs
, forceInvalidate
, ForceInvalidate(..)

View File

@ -3,6 +3,7 @@
module Web.ServerSession.Frontend.Yesod.Internal
( simpleBackend
, backend
, IsSessionMap(..)
, createCookie
, findSessionId
, forceInvalidate
@ -12,6 +13,7 @@ import Control.Monad (guard)
import Control.Monad.IO.Class (MonadIO)
import Data.ByteString (ByteString)
import Data.Default (def)
import Data.Text (Text)
import Web.PathPieces (toPathPiece)
import Web.ServerSession.Core
import Yesod.Core (MonadHandler)
@ -19,6 +21,7 @@ import Yesod.Core.Handler (setSessionBS)
import Yesod.Core.Types (Header(AddCookie), SessionBackend(..))
import qualified Data.ByteString.Char8 as B8
import qualified Data.Map as M
import qualified Data.Text.Encoding as TE
import qualified Data.Time as TI
import qualified Network.Wai as W
@ -53,42 +56,70 @@ import qualified Web.Cookie as C
-- . setAbsoluteTimeout (Just $ 60*60*24)
-- . setSecureCookies True
-- @
--
-- This is a simple version of 'backend' specialized for using
-- 'SessionMap' as 'SessionData'. If you want to use a different
-- session data type, please use 'backend' directly (tip: take a
-- peek at this function's source).
simpleBackend
:: (MonadIO m, Storage s)
=> (State s -> State s) -- ^ Set any options on the @serversession@ state.
-> s -- ^ Storage backend.
:: (MonadIO m, Storage sto, SessionData sto ~ SessionMap)
=> (State sto -> State sto) -- ^ Set any options on the @serversession@ state.
-> sto -- ^ Storage backend.
-> m (Maybe SessionBackend) -- ^ Yesod session backend (always @Just@).
simpleBackend opts s =
return . Just . backend . opts =<< createState s
-- | Construct the server-side session backend using the given
-- state.
-- state. This is a generalized version of 'simpleBackend'.
--
-- In order to use the Yesod frontend, you 'SessionData' needs to
-- implement 'IsSessionMap'.
backend
:: Storage s
=> State s -- ^ @serversession@ state, incl. storage backend.
:: (Storage sto, IsSessionMap (SessionData sto))
=> State sto -- ^ @serversession@ state, incl. storage backend.
-> SessionBackend -- ^ Yesod session backend.
backend state =
SessionBackend {
sbLoadSession = \req -> do
backend state = SessionBackend { sbLoadSession = load }
where
load req = do
let rawSessionId = findSessionId cookieNameBS req
(sessionMap, saveSessionToken) <- loadSession state rawSessionId
(data_, saveSessionToken) <- loadSession state rawSessionId
let save =
fmap ((:[]) . maybe (deleteCookie state cookieNameBS)
(createCookie state cookieNameBS)) .
saveSession state saveSessionToken
return (sessionMap, save)
}
where
saveSession state saveSessionToken .
fromSessionMap
return (toSessionMap data_, save)
cookieNameBS = TE.encodeUtf8 $ getCookieName state
----------------------------------------------------------------------
-- | Class for session data types meant to be used with the Yesod
-- frontend. The only session interface Yesod provides is via
-- session variables, so your data type needs to be convertible
-- from/to a 'M.Map' of 'Text' to 'ByteString'.
class IsSessionMap sess where
toSessionMap :: sess -> M.Map Text ByteString
fromSessionMap :: M.Map Text ByteString -> sess
instance IsSessionMap SessionMap where
toSessionMap = unSessionMap
fromSessionMap = SessionMap
----------------------------------------------------------------------
-- | Create a cookie for the given session.
--
-- The cookie expiration is set via 'nextExpires'. Note that
-- this is just an optimization, as the expiration is checked on
-- the server-side as well.
createCookie :: State s -> ByteString -> Session -> Header
createCookie :: State sto -> ByteString -> Session sess -> Header
createCookie state cookieNameBS session =
-- Generate a cookie with the final session ID.
AddCookie def
@ -110,7 +141,7 @@ createCookie state cookieNameBS session =
--
-- * If the user had a session cookie that was invalidated,
-- this will remove the invalid cookie from the client.
deleteCookie :: State s -> ByteString -> Header
deleteCookie :: State sto -> ByteString -> Header
deleteCookie state cookieNameBS =
AddCookie def
{ C.setCookieName = cookieNameBS

View File

@ -37,7 +37,9 @@ library
OverloadedStrings
RecordWildCards
ScopedTypeVariables
StandaloneDeriving
TypeFamilies
UndecidableInstances
ghc-options: -Wall
@ -53,9 +55,12 @@ test-suite tests
, serversession
extensions:
DeriveDataTypeable
FlexibleContexts
OverloadedStrings
StandaloneDeriving
TupleSections
TypeFamilies
UndecidableInstances
main-is: Main.hs
ghc-options: -Wall -threaded -with-rtsopts=-N

View File

@ -6,9 +6,11 @@ module Web.ServerSession.Core
, Session(..)
, Storage(..)
, StorageException(..)
, IsSessionData(..)
, DecomposedSession(..)
-- * For serversession frontends
, SessionMap
, SessionMap(..)
, State
, createState
, getCookieName

View File

@ -1,13 +1,21 @@
-- | Internal module exposing the guts of the package. Use at
-- your own risk. No API stability guarantees apply.
--
-- @UndecidableInstances@ is required in order to implement @Eq@,
-- @Ord@, @Show@, etc. on data types that have @Decomposed@
-- fields, and should be fairly safe.
module Web.ServerSession.Core.Internal
( SessionId(..)
, checkSessionId
, generateSessionId
, SessionMap
, AuthId
, Session(..)
, SessionMap(..)
, IsSessionData(..)
, DecomposedSession(..)
, Storage(..)
, StorageException(..)
@ -32,10 +40,7 @@ module Web.ServerSession.Core.Internal
, saveSession
, SaveSessionToken(..)
, invalidateIfNeeded
, DecomposedSession(..)
, decomposeSession
, saveSessionOnDb
, toSessionMap
, forceInvalidateKey
, ForceInvalidate(..)
) where
@ -64,7 +69,8 @@ import qualified Data.Text.Encoding as TE
-- | The ID of a session. Always 18 bytes base64url-encoded as
-- 24 characters.
-- 24 characters. The @sess@ type variable is a phantom type for
-- the session data type this session ID points to.
--
-- Implementation notes:
--
@ -72,24 +78,24 @@ import qualified Data.Text.Encoding as TE
--
-- * Use 'generateSessionId' for securely generating new
-- session IDs.
newtype SessionId = S { unS :: Text }
newtype SessionId sess = S { unS :: Text }
deriving (Eq, Ord, Show, Read, Typeable)
-- | Sanity checks input on 'fromPathPiece' (untrusted input).
instance PathPiece SessionId where
instance PathPiece (SessionId sess) where
toPathPiece = unS
fromPathPiece = checkSessionId
instance A.FromJSON SessionId where
instance A.FromJSON (SessionId sess) where
parseJSON = fmap S . A.parseJSON
instance A.ToJSON SessionId where
instance A.ToJSON (SessionId sess) where
toJSON = A.toJSON . unS
-- | (Internal) Check that the given text is a base64url-encoded
-- representation of 18 bytes.
checkSessionId :: Text -> Maybe SessionId
checkSessionId :: Text -> Maybe (SessionId sess)
checkSessionId text = do
guard (T.length text == 24)
let bs = TE.encodeUtf8 text
@ -99,23 +105,13 @@ checkSessionId text = do
-- | Securely generate a new SessionId.
generateSessionId :: N.Generator -> IO SessionId
generateSessionId :: N.Generator -> IO (SessionId sess)
generateSessionId = fmap S . N.nonce128urlT
----------------------------------------------------------------------
-- | A session map.
--
-- This is the representation of a session used by the
-- @serversession@ family of packages, transferring data between
-- this core package and frontend packages. Serversession
-- storage backend packages should use 'Session'. End users
-- should use their web framework's support for sessions.
type SessionMap = M.Map Text ByteString
-- | Value of the 'authKey' session key.
type AuthId = ByteString
@ -124,49 +120,179 @@ type AuthId = ByteString
--
-- This representation is used by the @serversession@ family of
-- packages, transferring data between this core package and
-- storage backend packages. Serversession frontend packages
-- should use 'SessionMap'. End users should use their web
-- framework's support for sessions.
data Session =
-- storage backend packages. The @sess@ type variable describes
-- the session data type.
data Session sess =
Session
{ sessionKey :: SessionId
{ sessionKey :: SessionId sess
-- ^ Session ID, primary key.
, sessionAuthId :: Maybe AuthId
-- ^ Value of 'authKey' session key, separate from the rest.
, sessionData :: SessionMap
, sessionData :: Decomposed sess
-- ^ Rest of the session data.
, sessionCreatedAt :: UTCTime
-- ^ When this session was created.
, sessionAccessedAt :: UTCTime
-- ^ When this session was last accessed.
} deriving (Eq, Ord, Show, Typeable)
} deriving (Typeable)
deriving instance Eq (Decomposed sess) => Eq (Session sess)
deriving instance Ord (Decomposed sess) => Ord (Session sess)
deriving instance Show (Decomposed sess) => Show (Session sess)
-- | A storage backend for server-side sessions.
class MonadIO (TransactionM s) => Storage s where
-- | A @newtype@ for a common session map.
--
-- This is a common representation of a session. Although
-- @serversession@ has generalized session data types, you can
-- use this one if you don't want to worry about it. We strive
-- to support this session data type on all frontends and storage
-- backends.
newtype SessionMap =
SessionMap { unSessionMap :: M.Map Text ByteString }
deriving (Eq, Ord, Show, Read, Typeable)
----------------------------------------------------------------------
-- | Class for data types to be used as session data
-- (cf. 'sessionData', 'SessionData').
--
-- The @Show@ constrain is needed for 'StorageException'.
class ( Show (Decomposed sess)
, Typeable (Decomposed sess)
, Typeable sess
) => IsSessionData sess where
-- | The type of the session data after being decomposed. This
-- may be the same as @sess@.
type Decomposed sess :: *
-- | Empty session data.
emptySession :: sess
-- | Decompose session data into:
--
-- * The auth ID of the logged in user (cf. 'setAuthKey',
-- 'dsAuthId').
--
-- * If the session is being forced to be invalidated
-- (cf. 'forceInvalidateKey', 'ForceInvalidate').
--
-- * The rest of the session data (cf. 'Decomposed').
decomposeSession
:: Text -- ^ The auth key (cf. 'setAuthKey').
-> sess -- ^ Session data to be decomposed.
-> DecomposedSession sess -- ^ Decomposed session data.
-- | Recompose a decomposed session again into a proper @sess@.
recomposeSession
:: Text -- ^ The auth key (cf. 'setAuthKey').
-> Maybe AuthId -- ^ The @AuthId@, if any.
-> Decomposed sess -- ^ Decomposed session data to be recomposed.
-> sess -- ^ Recomposed session data.
-- | Returns @True@ when both session datas are to be
-- considered the same.
--
-- This is used to optimize storage calls
-- (cf. 'setTimeoutResolution'). Always returning @False@ will
-- disable the optimization but won't have any other adverse
-- effects.
--
-- For data types implementing 'Eq', this is usually a good
-- implementation:
--
-- @
-- isSameDecomposed _ = (==)
-- @
isSameDecomposed :: proxy sess -> Decomposed sess -> Decomposed sess -> Bool
-- | Returns @True@ if the decomposed session data is to be
-- considered @empty@.
--
-- This is used to avoid storing empty session data if at all
-- possible. Always returning @False@ will disable the
-- optimization but won't have any other adverse effects.
isDecomposedEmpty :: proxy sess -> Decomposed sess -> Bool
-- | A 'SessionMap' decomposes into a 'SessionMap' minus the keys
-- that were removed. The auth key is added back when
-- recomposing.
instance IsSessionData SessionMap where
type Decomposed SessionMap = SessionMap
emptySession = SessionMap M.empty
isSameDecomposed _ = (==)
decomposeSession authKey_ (SessionMap sm1) =
let (authId, sm2) = M.updateLookupWithKey (\_ _ -> Nothing) authKey_ sm1
(force, sm3) = M.updateLookupWithKey (\_ _ -> Nothing) forceInvalidateKey sm2
in DecomposedSession
{ dsAuthId = authId
, dsForceInvalidate = maybe DoNotForceInvalidate (read . B8.unpack) force
, dsDecomposed = SessionMap sm3 }
recomposeSession authKey_ mauthId (SessionMap sm) =
SessionMap $ maybe id (M.insert authKey_) mauthId sm
isDecomposedEmpty _ = M.null . unSessionMap
-- | A session data type @sess@ with its special variables taken apart.
data DecomposedSession sess =
DecomposedSession
{ dsAuthId :: !(Maybe ByteString)
, dsForceInvalidate :: !ForceInvalidate
, dsDecomposed :: !(Decomposed sess)
} deriving (Typeable)
deriving instance Eq (Decomposed sess) => Eq (DecomposedSession sess)
deriving instance Ord (Decomposed sess) => Ord (DecomposedSession sess)
deriving instance Show (Decomposed sess) => Show (DecomposedSession sess)
----------------------------------------------------------------------
-- | A storage backend @sto@ for server-side sessions. The
-- @sess@ session data type and\/or its 'Decomposed' version may
-- be constrained depending on the storage backend capabilities.
class ( Typeable sto
, MonadIO (TransactionM sto)
, IsSessionData (SessionData sto)
) => Storage sto where
-- | The session data type used by this storage.
type SessionData sto :: *
-- | Monad where transactions happen for this backend.
-- We do not require transactions to be ACID.
type TransactionM s :: * -> *
type TransactionM sto :: * -> *
-- | Run a transaction on the IO monad.
runTransactionM :: s -> TransactionM s a -> IO a
runTransactionM :: sto -> TransactionM sto a -> IO a
-- | Get the session for the given session ID. Returns
-- @Nothing@ if the session is not found.
getSession :: s -> SessionId -> TransactionM s (Maybe Session)
getSession
:: sto
-> SessionId (SessionData sto)
-> TransactionM sto (Maybe (Session (SessionData sto)))
-- | Delete the session with given session ID. Does not do
-- anything if the session is not found.
deleteSession :: s -> SessionId -> TransactionM s ()
deleteSession :: sto -> SessionId (SessionData sto) -> TransactionM sto ()
-- | Delete all sessions of the given auth ID. Does not do
-- anything if there are no sessions of the given auth ID.
deleteAllSessionsOfAuthId :: s -> AuthId -> TransactionM s ()
deleteAllSessionsOfAuthId :: sto -> AuthId -> TransactionM sto ()
-- | Insert a new session. Throws 'SessionAlreadyExists' if
-- there already exists a session with the same session ID (we
-- only call this method after generating a fresh session ID).
insertSession :: s -> Session -> TransactionM s ()
insertSession :: sto -> Session (SessionData sto) -> TransactionM sto ()
-- | Replace the contents of a session. Throws
-- 'SessionDoesNotExist' if there is no session with the given
@ -211,29 +337,32 @@ class MonadIO (TransactionM s) => Storage s where
-- Most of the time this discussion does not matter.
-- Invalidations usually occur at times where only one request
-- is flying.
replaceSession :: s -> Session -> TransactionM s ()
replaceSession :: sto -> Session (SessionData sto) -> TransactionM sto ()
-- | Common exceptions that may be thrown by any storage.
data StorageException =
data StorageException sto =
-- | Exception thrown by 'insertSession' whenever a session
-- with same ID already exists.
SessionAlreadyExists
{ seExistingSession :: Session
, seNewSession :: Session }
{ seExistingSession :: Session (SessionData sto)
, seNewSession :: Session (SessionData sto) }
-- | Exception thrown by 'replaceSession' whenever trying to
-- replace a session that is not present on the storage.
| SessionDoesNotExist
{ seNewSession :: Session }
deriving (Show, Typeable)
{ seNewSession :: Session (SessionData sto) }
deriving (Typeable)
instance E.Exception StorageException where
deriving instance Eq (Decomposed (SessionData sto)) => Eq (StorageException sto)
deriving instance Ord (Decomposed (SessionData sto)) => Ord (StorageException sto)
deriving instance Show (Decomposed (SessionData sto)) => Show (StorageException sto)
instance Storage sto => E.Exception (StorageException sto) where
----------------------------------------------------------------------
-- TODO: delete expired sessions.
-- | The server-side session backend needs to maintain some state
@ -256,10 +385,10 @@ instance E.Exception StorageException where
-- and/or secure ('setSecureCookies').
--
-- Create a new 'State' using 'createState'.
data State s =
data State sto =
State
{ generator :: !N.Generator
, storage :: !s
, storage :: !sto
, cookieName :: !Text
, authKey :: !Text
, idleTimeout :: !(Maybe NominalDiffTime)
@ -273,7 +402,7 @@ data State s =
-- | Create a new 'State' for the server-side session backend
-- using the given storage backend.
createState :: MonadIO m => s -> m (State s)
createState :: MonadIO m => sto -> m (State sto)
createState sto = do
gen <- N.new
return State
@ -294,13 +423,21 @@ createState sto = do
-- Defaults to \"JSESSIONID\", which is a generic cookie name
-- used by many frameworks thus making it harder to fingerprint
-- this implementation.
setCookieName :: Text -> State s -> State s
setCookieName :: Text -> State sto -> State sto
setCookieName val state = state { cookieName = val }
-- | Set the name of the session variable that keeps track of the
-- logged user. Defaults to \"_ID\" (used by @yesod-auth@).
setAuthKey :: Text -> State s -> State s
-- logged user.
--
-- This setting is used by session data types that are
-- @Map@-alike, using a @lookup@ function. However, the
-- 'IsSessionData' instance of a session data type may choose not
-- to use it. For example, if you implemented a custom data
-- type, you could return the @AuthId@ without needing a lookup.
--
-- Defaults to \"_ID\" (used by @yesod-auth@).
setAuthKey :: Text -> State sto -> State sto
setAuthKey val state = state { authKey = val }
@ -318,7 +455,7 @@ setAuthKey val state = state { authKey = val }
-- (<https://www.owasp.org/index.php/Session_Management_Cheat_Sheet#Idle_Timeout Source>)
--
-- Defaults to 7 days.
setIdleTimeout :: Maybe NominalDiffTime -> State s -> State s
setIdleTimeout :: Maybe NominalDiffTime -> State sto -> State sto
setIdleTimeout (Just d) _ | d <= 0 = error "serversession/setIdleTimeout: Timeout should be positive."
setIdleTimeout val state = state { idleTimeout = val }
@ -338,7 +475,7 @@ setIdleTimeout val state = state { idleTimeout = val }
-- (<https://www.owasp.org/index.php/Session_Management_Cheat_Sheet#Absolute_Timeout Source>)
--
-- Defaults to 60 days.
setAbsoluteTimeout :: Maybe NominalDiffTime -> State s -> State s
setAbsoluteTimeout :: Maybe NominalDiffTime -> State sto -> State sto
setAbsoluteTimeout (Just d) _ | d <= 0 = error "serversession/setAbsoluteTimeout: Timeout should be positive."
setAbsoluteTimeout val state = state { absoluteTimeout = val }
@ -368,7 +505,7 @@ setAbsoluteTimeout val state = state { absoluteTimeout = val }
-- becomes disabled and the session will always be updated.
--
-- Defaults to 10 minutes.
setTimeoutResolution :: Maybe NominalDiffTime -> State s -> State s
setTimeoutResolution :: Maybe NominalDiffTime -> State sto -> State sto
setTimeoutResolution (Just d) _ | d <= 0 = error "serversession/setTimeoutResolution: Resolution should be positive."
setTimeoutResolution val state = state { timeoutResolution = val }
@ -382,7 +519,7 @@ setTimeoutResolution val state = state { timeoutResolution = val }
-- cookie is set to expire in 10 years.
--
-- Defaults to @True@.
setPersistentCookies :: Bool -> State s -> State s
setPersistentCookies :: Bool -> State sto -> State sto
setPersistentCookies val state = state { persistentCookies = val }
@ -393,7 +530,7 @@ setPersistentCookies val state = state { persistentCookies = val }
-- It's highly recommended to set this attribute to @True@.
--
-- Defaults to @True@.
setHttpOnlyCookies :: Bool -> State s -> State s
setHttpOnlyCookies :: Bool -> State sto -> State sto
setHttpOnlyCookies val state = state { httpOnlyCookies = val }
@ -405,22 +542,22 @@ setHttpOnlyCookies val state = state { httpOnlyCookies = val }
-- @False@.
--
-- Defaults to @False@.
setSecureCookies :: Bool -> State s -> State s
setSecureCookies :: Bool -> State sto -> State sto
setSecureCookies val state = state { secureCookies = val }
-- | Cf. 'setCookieName'.
getCookieName :: State s -> Text
getCookieName :: State sto -> Text
getCookieName = cookieName
-- | Cf. 'setHttpOnlyCookies'.
getHttpOnlyCookies :: State s -> Bool
getHttpOnlyCookies :: State sto -> Bool
getHttpOnlyCookies = httpOnlyCookies
-- | Cf. 'setSecureCookies'.
getSecureCookies :: State s -> Bool
getSecureCookies :: State sto -> Bool
getSecureCookies = secureCookies
@ -432,25 +569,33 @@ getSecureCookies = secureCookies
--
-- Returns:
--
-- * The 'SessionMap' to be used by the frontend as the current
-- session's value.
-- * The session data @sess@ to be used by the frontend as the
-- current session's value.
--
-- * Information to be passed back to 'saveSession' on the end
-- of the request in order to save the session.
loadSession :: Storage s => State s -> Maybe ByteString -> IO (SessionMap, SaveSessionToken)
loadSession
:: Storage sto
=> State sto
-> Maybe ByteString
-> IO (SessionData sto, SaveSessionToken sto)
loadSession state mcookieVal = do
now <- getCurrentTime
let maybeInputId = mcookieVal >>= fromPathPiece . TE.decodeUtf8
get = runTransactionM (storage state) . getSession (storage state)
checkedGet = fmap (>>= checkExpired now state) . get
maybeInput <- maybe (return Nothing) checkedGet maybeInputId
let inputSessionMap = maybe M.empty (toSessionMap state) maybeInput
return (inputSessionMap, SaveSessionToken maybeInput now)
let inputData =
maybe
emptySession
(\s -> recomposeSession (authKey state) (sessionAuthId s) (sessionData s))
maybeInput
return (inputData, SaveSessionToken maybeInput now)
-- | Check if a session @s@ has expired. Returns the @Just s@ if
-- not expired, or @Nothing@ if expired.
checkExpired :: UTCTime {-^ Now. -} -> State s -> Session -> Maybe Session
checkExpired :: UTCTime {-^ Now. -} -> State sto -> Session sess -> Maybe (Session sess)
checkExpired now state session =
let expired = maybe False (< now) (nextExpires state session)
in guard (not expired) >> return session
@ -460,7 +605,7 @@ checkExpired now state session =
-- will expire assuming that it sees no activity until then.
-- Returns @Nothing@ iff the state does not have any expirations
-- set to @Just@.
nextExpires :: State s -> Session -> Maybe UTCTime
nextExpires :: State sto -> Session sess -> Maybe UTCTime
nextExpires State {..} Session {..} =
let viaIdle = flip addUTCTime sessionAccessedAt <$> idleTimeout
viaAbsolute = flip addUTCTime sessionCreatedAt <$> absoluteTimeout
@ -471,7 +616,7 @@ nextExpires State {..} Session {..} =
-- | Calculate the date that should be used for the cookie's
-- \"Expires\" field.
cookieExpires :: State s -> Session -> Maybe UTCTime
cookieExpires :: State sto -> Session sess -> Maybe UTCTime
cookieExpires State {..} _ | not persistentCookies = Nothing
cookieExpires state session =
Just $ fromMaybe tenYearsFromNow $ nextExpires state session
@ -481,9 +626,13 @@ cookieExpires state session =
-- | Opaque token containing the necessary information for
-- 'saveSession' to save the session.
data SaveSessionToken =
SaveSessionToken (Maybe Session) UTCTime
deriving (Eq, Show, Typeable)
data SaveSessionToken sto =
SaveSessionToken (Maybe (Session (SessionData sto))) UTCTime
deriving (Typeable)
deriving instance Eq (Decomposed (SessionData sto)) => Eq (SaveSessionToken sto)
deriving instance Ord (Decomposed (SessionData sto)) => Ord (SaveSessionToken sto)
deriving instance Show (Decomposed (SessionData sto)) => Show (SaveSessionToken sto)
-- | Save the session on the storage backend. A
@ -496,12 +645,17 @@ data SaveSessionToken =
-- and clear every other sesssion variable, then 'saveSession'
-- will invalidate the older session but will avoid creating a
-- new, empty one.
saveSession :: Storage s => State s -> SaveSessionToken -> SessionMap -> IO (Maybe Session)
saveSession state (SaveSessionToken maybeInput now) wholeOutputSessionMap =
saveSession
:: Storage sto
=> State sto
-> SaveSessionToken sto
-> SessionData sto
-> IO (Maybe (Session (SessionData sto)))
saveSession state (SaveSessionToken maybeInput now) outputData =
runTransactionM (storage state) $ do
let decomposedSessionMap = decomposeSession state wholeOutputSessionMap
newMaybeInput <- invalidateIfNeeded state maybeInput decomposedSessionMap
saveSessionOnDb state now newMaybeInput decomposedSessionMap
let outputDecomp = decomposeSession (authKey state) outputData
newMaybeInput <- invalidateIfNeeded state maybeInput outputDecomp
saveSessionOnDb state now newMaybeInput outputDecomp
-- | Invalidates an old session ID if needed. Returns the
@ -512,11 +666,11 @@ saveSession state (SaveSessionToken maybeInput now) wholeOutputSessionMap =
-- fixation attacks. We also invalidate when asked to via
-- 'forceInvalidate'.
invalidateIfNeeded
:: Storage s
=> State s
-> Maybe Session
-> DecomposedSession
-> TransactionM s (Maybe Session)
:: Storage sto
=> State sto
-> Maybe (Session (SessionData sto))
-> DecomposedSession (SessionData sto)
-> TransactionM sto (Maybe (Session (SessionData sto)))
invalidateIfNeeded state maybeInput DecomposedSession {..} = do
-- Decide which action to take.
-- "invalidateOthers implies invalidateCurrent" should be true below.
@ -531,26 +685,6 @@ invalidateIfNeeded state maybeInput DecomposedSession {..} = do
return $ guard (not invalidateCurrent) >> maybeInput
-- | A 'SessionMap' with its special variables taken apart.
data DecomposedSession =
DecomposedSession
{ dsAuthId :: !(Maybe ByteString)
, dsForceInvalidate :: !ForceInvalidate
, dsSessionMap :: !SessionMap
} deriving (Eq, Show, Typeable)
-- | Decompose a session (see 'DecomposedSession').
decomposeSession :: State s -> SessionMap -> DecomposedSession
decomposeSession state sm1 =
let (authId, sm2) = M.updateLookupWithKey (\_ _ -> Nothing) (authKey state) sm1
(force, sm3) = M.updateLookupWithKey (\_ _ -> Nothing) forceInvalidateKey sm2
in DecomposedSession
{ dsAuthId = authId
, dsForceInvalidate = maybe DoNotForceInvalidate (read . B8.unpack) force
, dsSessionMap = sm3 }
-- | Save a session on the database. If an old session is
-- supplied, it is replaced, otherwise a new session is
-- generated. If the session is empty, it is not saved and
@ -558,24 +692,30 @@ decomposeSession state sm1 =
-- is applied (cf. 'setTimeoutResolution'), the old session is
-- returned and no update is made.
saveSessionOnDb
:: Storage s
=> State s
-> UTCTime -- ^ Now.
-> Maybe Session -- ^ The old session, if any.
-> DecomposedSession -- ^ The session data to be saved.
-> TransactionM s (Maybe Session) -- ^ Copy of saved session.
:: forall sto. Storage sto
=> State sto
-> UTCTime -- ^ Now.
-> Maybe (Session (SessionData sto)) -- ^ The old session, if any.
-> DecomposedSession (SessionData sto) -- ^ The session data to be saved.
-> TransactionM sto (Maybe (Session (SessionData sto))) -- ^ Copy of saved session.
saveSessionOnDb _ _ Nothing (DecomposedSession Nothing _ m)
-- Return Nothing without doing anything whenever the session
-- is empty (including auth ID) and there was no prior session.
| M.null m = return Nothing
saveSessionOnDb State { timeoutResolution = Just res } now (Just old) (DecomposedSession authId _ sessionMap)
| isDecomposedEmpty proxy m = return Nothing
where
proxy :: Maybe (SessionData sto)
proxy = Nothing
saveSessionOnDb State { timeoutResolution = Just res } now (Just old) (DecomposedSession authId _ newSession)
-- If the data is the same and the old access time is within
-- the timeout resolution, just return the old session without
-- doing anything else.
| sessionData old == sessionMap &&
sessionAuthId old == authId &&
| sessionAuthId old == authId &&
isSameDecomposed proxy (sessionData old) newSession &&
abs (diffUTCTime now (sessionAccessedAt old)) < res =
return (Just old)
where
proxy :: Maybe (SessionData sto)
proxy = Nothing
saveSessionOnDb state now maybeInput DecomposedSession {..} = do
-- Generate properties if needed or take them from previous
-- saved session.
@ -593,7 +733,7 @@ saveSessionOnDb state now maybeInput DecomposedSession {..} = do
let session = Session
{ sessionKey = key
, sessionAuthId = dsAuthId
, sessionData = dsSessionMap
, sessionData = dsDecomposed
, sessionCreatedAt = createdAt
, sessionAccessedAt = now
}
@ -601,12 +741,6 @@ saveSessionOnDb state now maybeInput DecomposedSession {..} = do
return (Just session)
-- | Create a 'SessionMap' from a 'Session'.
toSessionMap :: State s -> Session -> SessionMap
toSessionMap state Session {..} =
maybe id (M.insert $ authKey state) sessionAuthId sessionData
-- | The session key used to signal that the session ID should be
-- invalidated.
forceInvalidateKey :: Text

View File

@ -18,7 +18,7 @@ import qualified Data.Text as T
import qualified Data.Time as TI
-- | Execute all storage tests.
-- | Execute all storage tests using 'SessionMap'.
--
-- This function is meant to be used with @hspec@. However, we
-- don't want to depend on @hspec@, so it takes the relevant
@ -42,8 +42,8 @@ import qualified Data.Time as TI
-- sequentially in order to reduce the peak memory usage of the
-- test suite.
allStorageTests
:: forall m s. (Monad m, Storage s)
=> s -- ^ Storage backend.
:: forall m sto. (Monad m, Storage sto, SessionData sto ~ SessionMap)
=> sto -- ^ Storage backend.
-> (String -> IO () -> m ()) -- ^ @hspec@'s it.
-> (forall a. IO a -> m a) -- ^ @hspec@'s runIO.
-> (m () -> m ()) -- ^ @hspec@'s parallel
@ -52,7 +52,7 @@ allStorageTests
-> (forall a e. Exception e => IO a -> (e -> Bool) -> IO ()) -- ^ @hspec@'s shouldThrow.
-> m ()
allStorageTests storage it runIO parallel _shouldBe shouldReturn shouldThrow = do
let run :: forall a. TransactionM s a -> IO a
let run :: forall a. TransactionM sto a -> IO a
run = runTransactionM storage
gen <- runIO N.new
@ -131,7 +131,8 @@ allStorageTests storage it runIO parallel _shouldBe shouldReturn shouldThrow = d
run (insertSession storage s1)
run (getSession storage sid) `shouldReturn` Just s1
run (insertSession storage s3) `shouldThrow`
(\(SessionAlreadyExists s1' s3') -> s1 == s1' && s3 == s3')
(\(SessionAlreadyExists s1' s3' :: StorageException sto) ->
s1 == s1' && s3 == s3')
run (getSession storage sid) `shouldReturn` Just s1
-- replaceSession
@ -153,7 +154,8 @@ allStorageTests storage it runIO parallel _shouldBe shouldReturn shouldThrow = d
s <- generateSession gen HasAuthId
let sid = sessionKey s
run (getSession storage sid) `shouldReturn` Nothing
run (replaceSession storage s) `shouldThrow` (\(SessionDoesNotExist s') -> s == s')
run (replaceSession storage s) `shouldThrow`
(\(SessionDoesNotExist s' :: StorageException sto) -> s == s')
run (getSession storage sid) `shouldReturn` Nothing
run (insertSession storage s)
run (getSession storage sid) `shouldReturn` Just s
@ -169,11 +171,11 @@ allStorageTests storage it runIO parallel _shouldBe shouldReturn shouldThrow = d
let session = Session
{ sessionKey = sid
, sessionAuthId = Nothing
, sessionData = M.fromList vals
, sessionData = SessionMap $ M.fromList vals
, sessionCreatedAt = now
, sessionAccessedAt = now
}
ver2 = session { sessionData = M.empty }
ver2 = session { sessionData = SessionMap M.empty }
run (getSession storage sid) `shouldReturn` Nothing
run (insertSession storage session)
run (getSession storage sid) `shouldReturn` (Just session)
@ -204,7 +206,7 @@ generateAuthId = N.nonce128url
-- | Generate a random session for our tests.
generateSession :: N.Generator -> HasAuthId -> IO Session
generateSession :: N.Generator -> HasAuthId -> IO (Session SessionMap)
generateSession gen hasAuthId = do
sid <- generateSessionId gen
authId <-
@ -219,7 +221,7 @@ generateSession gen hasAuthId = do
return Session
{ sessionKey = sid
, sessionAuthId = authId
, sessionData = data_
, sessionData = SessionMap data_
, sessionCreatedAt = TI.addUTCTime (-1000) now
, sessionAccessedAt = now
}

View File

@ -58,7 +58,7 @@ main = hspec $ parallel $ do
return $ fromPathPiece (toPathPiece sid) Q.=== Just sid
it "does not accept as valid some example invalid session IDs" $ do
let parse = fromPathPiece :: T.Text -> Maybe SessionId
let parse = fromPathPiece :: T.Text -> Maybe (SessionId SessionMap)
parse "" `shouldBe` Nothing
parse "123456789-123456789-123" `shouldBe` Nothing
parse "123456789-123456789-12345" `shouldBe` Nothing
@ -95,7 +95,7 @@ main = hspec $ parallel $ do
let point1 = 0.1 {- second -} :: Double
now <- TI.getCurrentTime
abs (realToFrac $ TI.diffUTCTime now time) `shouldSatisfy` (< point1)
sessionMap `shouldBe` M.empty
sessionMap `shouldBe` TNTSessionData
msession `shouldSatisfy` isNothing
it "returns empty session and token when the session ID cookie is not present" $ do
@ -119,7 +119,7 @@ main = hspec $ parallel $ do
let session = Session
{ sessionKey = S "123456789-123456789-1234"
, sessionAuthId = Just authId
, sessionData = M.fromList [("a", "b"), ("c", "d")]
, sessionData = mkSessionMap [("a", "b"), ("c", "d")]
, sessionCreatedAt = TI.addUTCTime (-10) fakenow
, sessionAccessedAt = TI.addUTCTime (-5) fakenow
}
@ -127,7 +127,7 @@ main = hspec $ parallel $ do
st <- createState =<< prepareMockStorage [session]
(retSessionMap, SaveSessionToken msession _now) <-
loadSession st (Just $ B8.pack $ T.unpack $ unS $ sessionKey session)
retSessionMap `shouldBe` M.insert (authKey st) authId (sessionData session)
retSessionMap `shouldBe` onSM (M.insert (authKey st) authId) (sessionData session)
msession `shouldBe` Just session
describe "checkExpired" $ do
@ -214,31 +214,31 @@ main = hspec $ parallel $ do
it "works for a complex example" $ do
sto <- emptyMockStorage
st <- createState sto
saveSession st (SaveSessionToken Nothing fakenow) M.empty `shouldReturn` Nothing
saveSession st (SaveSessionToken Nothing fakenow) emptySM `shouldReturn` Nothing
getMockOperations sto `shouldReturn` []
let m1 = M.singleton "a" "b"
let m1 = mkSessionMap [("a", "b")]
Just session1 <- saveSession st (SaveSessionToken Nothing fakenow) m1
sessionAuthId session1 `shouldBe` Nothing
sessionData session1 `shouldBe` m1
getMockOperations sto `shouldReturn` [InsertSession session1]
let m2 = M.insert (authKey st) "john" m1
let m2 = onSM (M.insert (authKey st) "john") m1
Just session2 <- saveSession st (SaveSessionToken (Just session1) fakenow) m2
sessionAuthId session2 `shouldBe` Just "john"
sessionData session2 `shouldBe` m1
sessionKey session2 == sessionKey session1 `shouldBe` False
getMockOperations sto `shouldReturn` [DeleteSession (sessionKey session1), InsertSession session2]
let m3 = M.insert forceInvalidateKey (B8.pack $ show AllSessionIdsOfLoggedUser) m2
let m3 = onSM (M.insert forceInvalidateKey (B8.pack $ show AllSessionIdsOfLoggedUser)) m2
Just session3 <- saveSession st (SaveSessionToken (Just session2) fakenow) m3
session3 `shouldBe` session2 { sessionKey = sessionKey session3 }
getMockOperations sto `shouldReturn`
[DeleteSession (sessionKey session2), DeleteAllSessionsOfAuthId "john", InsertSession session3]
let m4 = M.insert "x" "y" m2
let m4 = onSM (M.insert "x" "y") m2
Just session4 <- saveSession st (SaveSessionToken (Just session3) fakenow) m4
session4 `shouldBe` session3 { sessionData = M.delete (authKey st) m4 }
session4 `shouldBe` session3 { sessionData = onSM (M.delete (authKey st)) m4 }
getMockOperations sto `shouldReturn` [ReplaceSession session4]
Just session5 <- saveSession st (SaveSessionToken (Just session4) (TI.addUTCTime 10 fakenow)) m4
@ -250,18 +250,18 @@ main = hspec $ parallel $ do
let oldSession = Session
{ sessionKey = S "123456789-123456789-1234"
, sessionAuthId = authId
, sessionData = M.empty
, sessionData = emptySM
, sessionCreatedAt = TI.addUTCTime (-10) fakenow
, sessionAccessedAt = TI.addUTCTime (-5) fakenow }
sto <- prepareMockStorage [oldSession]
st <- createState sto
return (oldSession, sto, st)
return (oldSession, sto :: MockStorage SessionMap, st)
allEdges = let x = [Nothing, Just "john", Just "jane"] in (,) <$> x <*> x
it "does not invalidate when not changing auth ID nor explicitly requesting" $ do
forM_ [Nothing, Just "john"] $ \authId -> do
(session, sto, st) <- prepareInvalidateIfNeeded authId
let d = DecomposedSession authId DoNotForceInvalidate M.empty
let d = DecomposedSession authId DoNotForceInvalidate emptySM
invalidateIfNeeded st (Just session) d `shouldReturn` Just session
getMockOperations sto `shouldReturn` []
@ -270,21 +270,21 @@ main = hspec $ parallel $ do
, (Just "admin", Nothing)
, (Nothing, Just "joe") ] $ \edgeTransition -> do
(session, sto, st) <- prepareInvalidateIfNeeded (fst edgeTransition)
let d = DecomposedSession (snd edgeTransition) DoNotForceInvalidate M.empty
let d = DecomposedSession (snd edgeTransition) DoNotForceInvalidate emptySM
invalidateIfNeeded st (Just session) d `shouldReturn` Nothing
getMockOperations sto `shouldReturn` [DeleteSession (sessionKey session)]
it "invalidates the current session when CurrentSessionId is forced" $ do
forM_ allEdges $ \edgeTransition -> do
(session, sto, st) <- prepareInvalidateIfNeeded (fst edgeTransition)
let d = DecomposedSession (snd edgeTransition) CurrentSessionId M.empty
let d = DecomposedSession (snd edgeTransition) CurrentSessionId emptySM
invalidateIfNeeded st (Just session) d `shouldReturn` Nothing
getMockOperations sto `shouldReturn` [DeleteSession (sessionKey session)]
it "invalidates all of the user's sessions when AllSessionIdsOfLoggedUser is forced" $ do
forM_ allEdges $ \edgeTransition -> do
(session, sto, st) <- prepareInvalidateIfNeeded (fst edgeTransition)
let d = DecomposedSession (snd edgeTransition) AllSessionIdsOfLoggedUser M.empty
let d = DecomposedSession (snd edgeTransition) AllSessionIdsOfLoggedUser emptySM
invalidateIfNeeded st (Just session) d `shouldReturn` Nothing
let expected = DeleteSession (sessionKey session) :
maybe [] ((:[]) . DeleteAllSessionsOfAuthId) (snd edgeTransition)
@ -296,19 +296,19 @@ main = hspec $ parallel $ do
let oldSession = Session
{ sessionKey = S "123456789-123456789-1234"
, sessionAuthId = Just "auth"
, sessionData = M.fromList [("a", "b"), ("c", "d")]
, sessionData = mkSessionMap [("a", "b"), ("c", "d")]
, sessionCreatedAt = TI.addUTCTime (-10) fakenow
, sessionAccessedAt = TI.addUTCTime (-5) fakenow }
sto <- prepareMockStorage [oldSession]
st <- createState sto
return (oldSession, sto, st)
emptyDecomp = DecomposedSession Nothing DoNotForceInvalidate M.empty
return (oldSession, sto :: MockStorage SessionMap, st)
emptyDecomp = DecomposedSession Nothing DoNotForceInvalidate emptySM
it "inserts new sessions when there wasn't an old one" $ do
sto <- emptyMockStorage
st <- createState sto
st <- createState (sto :: MockStorage SessionMap)
let d = DecomposedSession a DoNotForceInvalidate m
m = M.fromList [("a", "b"), ("c", "d")]
m = mkSessionMap [("a", "b"), ("c", "d")]
a = Just "auth"
Just session <- saveSessionOnDb st fakenow Nothing d
getMockOperations sto `shouldReturn` [InsertSession session]
@ -320,7 +320,7 @@ main = hspec $ parallel $ do
it "replaces sesssions when there was an old one" $ do
(oldSession, sto, st) <- prepareSaveSessionOnDb
let d = DecomposedSession Nothing DoNotForceInvalidate m
m = M.fromList [("a", "b"), ("x", "y")]
m = mkSessionMap [("a", "b"), ("x", "y")]
Just session <- saveSessionOnDb st fakenow (Just oldSession) d
getMockOperations sto `shouldReturn` [ReplaceSession session]
session `shouldBe` oldSession
@ -336,7 +336,7 @@ main = hspec $ parallel $ do
it "saves session if it's empty but there was an old one" $ do
(oldSession, sto, st) <- prepareSaveSessionOnDb
let newSession = oldSession { sessionData = M.empty
let newSession = oldSession { sessionData = emptySM
, sessionAuthId = Nothing
, sessionAccessedAt = fakenow }
saveSessionOnDb st fakenow (Just oldSession) emptyDecomp `shouldReturn` Just newSession
@ -356,46 +356,42 @@ main = hspec $ parallel $ do
saveSessionOnDb st (t 1) (Just session1) d `shouldReturn` Just session2
getMockOperations sto `shouldReturn` [ReplaceSession session2]
describe "decomposeSession" $ do
describe "decomposeSession/SessionMap" $ do
let authKey_ = authKey stnull
prop "it is sane when not finding auth key or force invalidate key" $
\data_ ->
let sessionMap = mkSessionMap $ filter (notSpecial . fst) $ data_
notSpecial = flip notElem [authKey stnull, forceInvalidateKey] . T.pack
in decomposeSession stnull sessionMap `shouldBe`
in decomposeSession authKey_ sessionMap `shouldBe`
DecomposedSession Nothing DoNotForceInvalidate sessionMap
prop "parses the force invalidate key" $
\data_ ->
let sessionMap v = M.insert forceInvalidateKey (B8.pack $ show v) $ mkSessionMap data_
let sessionMap v = onSM (M.insert forceInvalidateKey (B8.pack $ show v)) $ mkSessionMap data_
allForces = [minBound..maxBound] :: [ForceInvalidate]
test v = dsForceInvalidate (decomposeSession stnull $ sessionMap v) Q.=== v
test v = dsForceInvalidate (decomposeSession authKey_ $ sessionMap v) Q.=== v
in Q.conjoin (test <$> allForces)
it "removes the auth key" $ do
let m = M.singleton "a" "b"; m' = M.insert (authKey stnull) "x" m
decomposeSession stnull m' `shouldBe`
DecomposedSession (Just "x") DoNotForceInvalidate m
decomposeSession authKey_ (SessionMap m') `shouldBe`
DecomposedSession (Just "x") DoNotForceInvalidate (SessionMap m)
describe "toSessionMap" $ do
let mkSession authId data_ = Session
{ sessionKey = error "irrelevant 1"
, sessionAuthId = authId
, sessionData = mkSessionMap data_
, sessionCreatedAt = error "irrelevant 2"
, sessionAccessedAt = error "irrelevant 3"
}
describe "recomposeSession/SessionMap" $ do
let authKey_ = authKey stnull
prop "does not change session data for sessions without auth ID" $
\data_ ->
let s = mkSession Nothing data_
in toSessionMap stnull s Q.=== sessionData s
let s = mkSessionMap data_
in recomposeSession authKey_ Nothing s Q.=== s
prop "adds (overwriting) the auth ID to the session data" $
\authId_ data_ ->
let s = mkSession (Just authId) ((T.unpack k, "foo") : data_)
k = authKey stnull
let s = mkSessionMap ((T.unpack authKey_, "foo") : data_)
authId = B8.pack authId_
in toSessionMap stnull s Q.=== M.adjust (const authId) k (sessionData s)
in recomposeSession authKey_ (Just authId) s
Q.=== onSM (M.adjust (const authId) authKey_) s
describe "MockStorage" $ do
sto <- runIO emptyMockStorage
@ -404,7 +400,19 @@ main = hspec $ parallel $ do
-- | Used to generate session maps on QuickCheck properties.
mkSessionMap :: [(String, String)] -> SessionMap
mkSessionMap = M.fromList . map (T.pack *** B8.pack)
mkSessionMap = SessionMap . M.fromList . map (T.pack *** B8.pack)
-- | Apply a function to a 'SessionMap'.
onSM
:: (M.Map T.Text B8.ByteString -> M.Map T.Text B8.ByteString)
-> (SessionMap -> SessionMap)
onSM f = SessionMap . f . unSessionMap
-- | Empty 'SessionMap'.
emptySM :: SessionMap
emptySM = emptySession
----------------------------------------------------------------------
@ -416,6 +424,7 @@ data TNTStorage = TNTStorage deriving (Typeable)
instance Storage TNTStorage where
type TransactionM TNTStorage = IO
type SessionData TNTStorage = TNTSessionData
runTransactionM _ = id
getSession = explode "getSession"
deleteSession = explode "deleteSession"
@ -436,29 +445,52 @@ data TNTExplosion = TNTExplosion String String deriving (Show, Typeable)
instance E.Exception TNTExplosion where
-- | Session data that explodes if it's used. Doesn't explode on
-- 'emptySession'.
data TNTSessionData = TNTSessionData deriving (Eq, Show, Typeable)
instance IsSessionData TNTSessionData where
type Decomposed TNTSessionData = ()
emptySession = TNTSessionData
isSameDecomposed _ = curry (explodeD "isSameDecomposed")
decomposeSession = curry (explodeD "decomposeSession")
recomposeSession = (curry . curry) (explodeD "recomposeSession")
isDecomposedEmpty _ = explodeD "isDecomposedEmpty"
-- | Implementation of all 'IsSessionData' methods of
-- 'TNTSessionData'.
explodeD :: Show a => String -> a -> b
explodeD fun = E.throw . TNTExplosion fun . show
----------------------------------------------------------------------
-- | A mock operation that was executed.
data MockOperation =
GetSession SessionId
| DeleteSession SessionId
data MockOperation sess =
GetSession (SessionId sess)
| DeleteSession (SessionId sess)
| DeleteAllSessionsOfAuthId AuthId
| InsertSession Session
| ReplaceSession Session
deriving (Eq, Show, Typeable)
| InsertSession (Session sess)
| ReplaceSession (Session sess)
deriving (Typeable)
deriving instance Eq (Decomposed sess) => Eq (MockOperation sess)
deriving instance Show (Decomposed sess) => Show (MockOperation sess)
-- | A mock storage used just for testing.
data MockStorage =
data MockStorage sess =
MockStorage
{ mockSessions :: I.IORef (M.Map SessionId Session)
, mockOperations :: I.IORef [MockOperation]
{ mockSessions :: I.IORef (M.Map (SessionId sess) (Session sess))
, mockOperations :: I.IORef [MockOperation sess]
}
deriving (Typeable)
instance Storage MockStorage where
type TransactionM MockStorage = IO
instance IsSessionData sess => Storage (MockStorage sess) where
type TransactionM (MockStorage sess) = IO
type SessionData (MockStorage sess) = sess
runTransactionM _ = id
getSession sto sid = do
-- We need to use atomicModifyIORef instead of readIORef
@ -478,7 +510,7 @@ instance Storage MockStorage where
M.insertLookupWithKey (\_ v _ -> v) (sessionKey session) session oldMap
in maybe
(newMap, return ())
(\oldVal -> (oldMap, E.throwIO $ SessionAlreadyExists oldVal session))
(\oldVal -> (oldMap, mockThrow $ SessionAlreadyExists oldVal session))
moldVal
addMockOperation sto (InsertSession session)
replaceSession sto session = do
@ -486,14 +518,22 @@ instance Storage MockStorage where
let (moldVal, newMap) =
M.insertLookupWithKey (\_ v _ -> v) (sessionKey session) session oldMap
in maybe
(oldMap, E.throwIO $ SessionDoesNotExist session)
(oldMap, mockThrow $ SessionDoesNotExist session)
(const (newMap, return ()))
moldVal
addMockOperation sto (ReplaceSession session)
-- | Specialization of 'E.throwIO' for 'MockStorage'.
mockThrow
:: IsSessionData sess
=> StorageException (MockStorage sess)
-> TransactionM (MockStorage sess) a
mockThrow = E.throwIO
-- | Creates empty mock storage.
emptyMockStorage :: IO MockStorage
emptyMockStorage :: IO (MockStorage sess)
emptyMockStorage =
MockStorage
<$> I.newIORef M.empty
@ -501,7 +541,7 @@ emptyMockStorage =
-- | Creates mock storage with the given sessions already existing.
prepareMockStorage :: [Session] -> IO MockStorage
prepareMockStorage :: [Session sess] -> IO (MockStorage sess)
prepareMockStorage sessions = do
sto <- emptyMockStorage
I.writeIORef (mockSessions sto) (M.fromList [(sessionKey s, s) | s <- sessions])
@ -510,10 +550,10 @@ prepareMockStorage sessions = do
-- | Get the list of mock operations that were made and clear
-- them. The operations are listed in chronological order.
getMockOperations :: MockStorage -> IO [MockOperation]
getMockOperations :: MockStorage sess -> IO [MockOperation sess]
getMockOperations = flip I.atomicModifyIORef' ((,) [] . reverse) . mockOperations
-- | Add a mock operations to the log.
addMockOperation :: MockStorage -> MockOperation -> IO ()
addMockOperation :: MockStorage sess -> MockOperation sess -> IO ()
addMockOperation sto op = I.atomicModifyIORef' (mockOperations sto) $ \ops -> (op:ops, ())