Generalize session data (huge commit).
This commit is contained in:
parent
821016a382
commit
3e33c58af0
@ -27,9 +27,12 @@ library
|
||||
Web.ServerSession.Backend.Acid
|
||||
Web.ServerSession.Backend.Acid.Internal
|
||||
extensions:
|
||||
ConstraintKinds
|
||||
DeriveDataTypeable
|
||||
FlexibleContexts
|
||||
TemplateHaskell
|
||||
TypeFamilies
|
||||
UndecidableInstances
|
||||
ghc-options: -Wall
|
||||
|
||||
|
||||
|
||||
@ -26,14 +26,16 @@ module Web.ServerSession.Backend.Acid.Internal
|
||||
|
||||
import Control.Monad.Reader (ask)
|
||||
import Control.Monad.State (get, modify', put)
|
||||
import Data.Acid (AcidState, Query, Update, makeAcidic, query, update)
|
||||
import Data.SafeCopy (deriveSafeCopy, base)
|
||||
import Data.Acid
|
||||
import Data.Acid.Advanced
|
||||
import Data.SafeCopy
|
||||
import Data.Typeable (Typeable)
|
||||
|
||||
import qualified Control.Exception as E
|
||||
import qualified Data.Map.Strict as M
|
||||
import qualified Data.Set as S
|
||||
import qualified Web.ServerSession.Core as SS
|
||||
import qualified Web.ServerSession.Core.Internal as SSI
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
@ -41,13 +43,13 @@ import qualified Web.ServerSession.Core as SS
|
||||
|
||||
-- | Map from session IDs to sessions. The most important map,
|
||||
-- allowing us efficient access to a session given its ID.
|
||||
type SessionIdToSession = M.Map SS.SessionId SS.Session
|
||||
type SessionIdToSession sess = M.Map (SS.SessionId sess) (SS.Session sess)
|
||||
|
||||
|
||||
-- | Map from auth IDs to session IDs. Allow us to invalidate
|
||||
-- all sessions of given user without having to iterate through
|
||||
-- the whole 'SessionIdToSession' map.
|
||||
type AuthIdToSessionId = M.Map SS.AuthId (S.Set SS.SessionId)
|
||||
type AuthIdToSessionId sess = M.Map SS.AuthId (S.Set (SS.SessionId sess))
|
||||
|
||||
|
||||
-- | The current sessions.
|
||||
@ -55,33 +57,37 @@ type AuthIdToSessionId = M.Map SS.AuthId (S.Set SS.SessionId)
|
||||
-- Besides the obvious map from session IDs to sessions, we also
|
||||
-- maintain a map of auth IDs to session IDs. This allow us to
|
||||
-- quickly invalidate all sessions of a given user.
|
||||
data ServerSessionAcidState =
|
||||
data ServerSessionAcidState sess =
|
||||
ServerSessionAcidState
|
||||
{ sessionIdToSession :: !SessionIdToSession
|
||||
, authIdToSessionId :: !AuthIdToSessionId
|
||||
} deriving (Show, Typeable)
|
||||
|
||||
deriveSafeCopy 0 'base ''SS.SessionId -- dangerous!
|
||||
deriveSafeCopy 0 'base ''SS.Session -- dangerous!
|
||||
deriveSafeCopy 0 'base ''ServerSessionAcidState
|
||||
{ sessionIdToSession :: !(SessionIdToSession sess)
|
||||
, authIdToSessionId :: !(AuthIdToSessionId sess)
|
||||
} deriving (Typeable)
|
||||
|
||||
|
||||
-- | Empty 'ServerSessionAcidState' used to bootstrap the 'AcidState'.
|
||||
emptyState :: ServerSessionAcidState
|
||||
emptyState :: ServerSessionAcidState sess
|
||||
emptyState = ServerSessionAcidState M.empty M.empty
|
||||
|
||||
|
||||
-- | Remove the given 'SessionId' from the set of the given
|
||||
-- 'AuthId' on the map. Does not do anything if no 'AuthId' is
|
||||
-- provided.
|
||||
removeSessionFromAuthId :: SS.SessionId -> Maybe SS.AuthId -> AuthIdToSessionId -> AuthIdToSessionId
|
||||
removeSessionFromAuthId
|
||||
:: SS.SessionId sess
|
||||
-> Maybe SS.AuthId
|
||||
-> AuthIdToSessionId sess
|
||||
-> AuthIdToSessionId sess
|
||||
removeSessionFromAuthId sid = maybe id (M.update (nothingfy . S.delete sid))
|
||||
|
||||
|
||||
-- | Insert the given session ID as being part of the given auth
|
||||
-- ID. Conceptually the opposite of 'removeSessionFromAuthId'.
|
||||
-- Does not do anything if no 'AuthId' is provided.
|
||||
insertSessionForAuthId :: SS.SessionId -> Maybe SS.AuthId -> AuthIdToSessionId -> AuthIdToSessionId
|
||||
insertSessionForAuthId
|
||||
:: SS.SessionId sess
|
||||
-> Maybe SS.AuthId
|
||||
-> AuthIdToSessionId sess
|
||||
-> AuthIdToSessionId sess
|
||||
insertSessionForAuthId sid = maybe id (flip (M.insertWith S.union) (S.singleton sid))
|
||||
|
||||
|
||||
@ -93,13 +99,61 @@ nothingfy s = if S.null s then Nothing else Just s
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
deriveSafeCopy 0 'base ''SS.SessionMap
|
||||
|
||||
|
||||
-- | We can't @deriveSafeCopy 0 'base ''SS.SessionId@ as
|
||||
-- otherwise we'd require an unneeded @SafeCopy sess@.
|
||||
instance SafeCopy (SS.SessionId sess) where
|
||||
putCopy = contain . safePut . SSI.unS
|
||||
getCopy = contain $ SSI.S <$> safeGet
|
||||
|
||||
|
||||
-- | We can't @deriveSafeCopy 0 'base ''SS.Session@ due to the
|
||||
-- required context.
|
||||
instance SafeCopy (SS.Decomposed sess) => SafeCopy (SS.Session sess) where
|
||||
putCopy (SS.Session key authId data_ createdAt accessedAt) = contain $ do
|
||||
put_t <- getSafePut
|
||||
safePut key
|
||||
safePut authId
|
||||
safePut data_
|
||||
put_t createdAt
|
||||
put_t accessedAt
|
||||
getCopy = contain $ do
|
||||
get_t <- getSafeGet
|
||||
SS.Session
|
||||
<$> safeGet
|
||||
<*> safeGet
|
||||
<*> safeGet
|
||||
<*> get_t
|
||||
<*> get_t
|
||||
|
||||
|
||||
-- | We can't @deriveSafeCopy 0 'base ''ServerSessionAcidState@ due
|
||||
-- to the required context.
|
||||
instance SafeCopy (SS.Decomposed sess) => SafeCopy (ServerSessionAcidState sess) where
|
||||
putCopy (ServerSessionAcidState sits aits) = contain $ do
|
||||
safePut sits
|
||||
safePut aits
|
||||
getCopy = contain $ ServerSessionAcidState <$> safeGet <*> safeGet
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
-- | Get the session for the given session ID.
|
||||
getSession :: SS.SessionId -> Query ServerSessionAcidState (Maybe SS.Session)
|
||||
getSession
|
||||
:: SS.Storage (AcidStorage sess)
|
||||
=> SS.SessionId sess
|
||||
-> Query (ServerSessionAcidState sess) (Maybe (SS.Session sess))
|
||||
getSession sid = M.lookup sid . sessionIdToSession <$> ask
|
||||
|
||||
|
||||
-- | Delete the session with given session ID.
|
||||
deleteSession :: SS.SessionId -> Update ServerSessionAcidState ()
|
||||
deleteSession
|
||||
:: SS.Storage (AcidStorage sess)
|
||||
=> SS.SessionId sess
|
||||
-> Update (ServerSessionAcidState sess) ()
|
||||
deleteSession sid = do
|
||||
let removeSession = M.updateLookupWithKey (\_ _ -> Nothing) sid
|
||||
modify' $ \state ->
|
||||
@ -110,7 +164,10 @@ deleteSession sid = do
|
||||
|
||||
|
||||
-- | Delete all sessions of the given auth ID.
|
||||
deleteAllSessionsOfAuthId :: SS.AuthId -> Update ServerSessionAcidState ()
|
||||
deleteAllSessionsOfAuthId
|
||||
:: SS.Storage (AcidStorage sess)
|
||||
=> SS.AuthId
|
||||
-> Update (ServerSessionAcidState sess) ()
|
||||
deleteAllSessionsOfAuthId authId = do
|
||||
let removeSession = maybe id (flip M.difference . M.fromSet (const ()))
|
||||
removeAuth = M.updateLookupWithKey (\_ _ -> Nothing) authId
|
||||
@ -121,11 +178,14 @@ deleteAllSessionsOfAuthId authId = do
|
||||
|
||||
|
||||
-- | Insert a new session.
|
||||
insertSession :: SS.Session -> Update ServerSessionAcidState ()
|
||||
insertSession
|
||||
:: SS.Storage (AcidStorage sess)
|
||||
=> SS.Session sess
|
||||
-> Update (ServerSessionAcidState sess) ()
|
||||
insertSession session = do
|
||||
let insertSess s =
|
||||
let (mold, new) = M.insertLookupWithKey (\_ v _ -> v) sid session s
|
||||
in maybe new (\old -> E.throw $ SS.SessionAlreadyExists old session) mold
|
||||
in maybe new (\old -> throwAS $ SS.SessionAlreadyExists old session) mold
|
||||
insertAuth = insertSessionForAuthId sid (SS.sessionAuthId session)
|
||||
sid = SS.sessionKey session
|
||||
modify' $ \state ->
|
||||
@ -135,14 +195,17 @@ insertSession session = do
|
||||
|
||||
|
||||
-- | Replace the contents of a session.
|
||||
replaceSession :: SS.Session -> Update ServerSessionAcidState ()
|
||||
replaceSession
|
||||
:: SS.Storage (AcidStorage sess)
|
||||
=> SS.Session sess
|
||||
-> Update (ServerSessionAcidState sess) ()
|
||||
replaceSession session = do
|
||||
-- Check that the old session exists while replacing it.
|
||||
ServerSessionAcidState sits aits <- get
|
||||
let (moldSession, sits') = M.insertLookupWithKey (\_ v _ -> v) sid session sits
|
||||
sid = SS.sessionKey session
|
||||
case moldSession of
|
||||
Nothing -> E.throw $ SS.SessionDoesNotExist session
|
||||
Nothing -> throwAS $ SS.SessionDoesNotExist session
|
||||
Just oldSession -> do
|
||||
-- Remove/insert the old auth ID from the map if needed.
|
||||
let modAits | oldAuthId == newAuthId = id
|
||||
@ -155,27 +218,103 @@ replaceSession session = do
|
||||
put (ServerSessionAcidState sits' aits')
|
||||
|
||||
|
||||
-- | Specialization of 'E.throw' for 'AcidStorage'.
|
||||
throwAS
|
||||
:: SS.Storage (AcidStorage sess)
|
||||
=> SS.StorageException (AcidStorage sess)
|
||||
-> a
|
||||
throwAS = E.throw
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
makeAcidic ''ServerSessionAcidState ['getSession, 'deleteSession, 'deleteAllSessionsOfAuthId, 'insertSession, 'replaceSession]
|
||||
|
||||
|
||||
-- | Session storage backend using @acid-state@.
|
||||
newtype AcidStorage =
|
||||
newtype AcidStorage sess =
|
||||
AcidStorage
|
||||
{ acidState :: AcidState ServerSessionAcidState
|
||||
{ acidState :: AcidState (ServerSessionAcidState sess)
|
||||
-- ^ Open 'AcidState' of server sessions.
|
||||
} deriving (Typeable)
|
||||
|
||||
|
||||
-- | We do not provide any ACID guarantees for different actions
|
||||
-- running inside the same @TransactionM AcidStorage@.
|
||||
instance SS.Storage AcidStorage where
|
||||
type TransactionM AcidStorage = IO
|
||||
instance ( SS.IsSessionData sess
|
||||
, SafeCopy sess
|
||||
, SafeCopy (SS.Decomposed sess)
|
||||
) => SS.Storage (AcidStorage sess) where
|
||||
type SessionData (AcidStorage sess) = sess
|
||||
type TransactionM (AcidStorage sess) = IO
|
||||
runTransactionM = const id
|
||||
getSession (AcidStorage s) = query s . GetSession
|
||||
deleteSession (AcidStorage s) = update s . DeleteSession
|
||||
deleteAllSessionsOfAuthId (AcidStorage s) = update s . DeleteAllSessionsOfAuthId
|
||||
insertSession (AcidStorage s) = update s . InsertSession
|
||||
replaceSession (AcidStorage s) = update s . ReplaceSession
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
|
||||
-- makeAcidic can't handle type variables, so we have to do
|
||||
-- everything by hand. :(
|
||||
|
||||
data GetSession sess = GetSession (SS.SessionId sess)
|
||||
data DeleteSession sess = DeleteSession (SS.SessionId sess)
|
||||
data DeleteAllSessionsOfAuthId sess = DeleteAllSessionsOfAuthId SS.AuthId
|
||||
data InsertSession sess = InsertSession (SS.Session sess)
|
||||
data ReplaceSession sess = ReplaceSession (SS.Session sess)
|
||||
|
||||
instance SafeCopy (GetSession sess) where
|
||||
putCopy (GetSession v) = contain $ safePut v
|
||||
getCopy = contain $ GetSession <$> safeGet
|
||||
|
||||
instance SafeCopy (DeleteSession sess) where
|
||||
putCopy (DeleteSession v) = contain $ safePut v
|
||||
getCopy = contain $ DeleteSession <$> safeGet
|
||||
|
||||
instance SafeCopy (DeleteAllSessionsOfAuthId sess) where
|
||||
putCopy (DeleteAllSessionsOfAuthId v) = contain $ safePut v
|
||||
getCopy = contain $ DeleteAllSessionsOfAuthId <$> safeGet
|
||||
|
||||
instance SafeCopy (SS.Decomposed sess) => SafeCopy (InsertSession sess) where
|
||||
putCopy (InsertSession v) = contain $ safePut v
|
||||
getCopy = contain $ InsertSession <$> safeGet
|
||||
|
||||
instance SafeCopy (SS.Decomposed sess) => SafeCopy (ReplaceSession sess) where
|
||||
putCopy (ReplaceSession v) = contain $ safePut v
|
||||
getCopy = contain $ ReplaceSession <$> safeGet
|
||||
|
||||
type AcidContext sess =
|
||||
( SS.IsSessionData sess
|
||||
, SafeCopy sess
|
||||
, SafeCopy (SS.Decomposed sess) )
|
||||
|
||||
instance AcidContext sess => QueryEvent (GetSession sess)
|
||||
instance AcidContext sess => UpdateEvent (DeleteSession sess)
|
||||
instance AcidContext sess => UpdateEvent (DeleteAllSessionsOfAuthId sess)
|
||||
instance AcidContext sess => UpdateEvent (InsertSession sess)
|
||||
instance AcidContext sess => UpdateEvent (ReplaceSession sess)
|
||||
|
||||
instance AcidContext sess => Method (GetSession sess) where
|
||||
type MethodResult (GetSession sess) = Maybe (SS.Session sess)
|
||||
type MethodState (GetSession sess) = ServerSessionAcidState sess
|
||||
instance AcidContext sess => Method (DeleteSession sess) where
|
||||
type MethodResult (DeleteSession sess) = ()
|
||||
type MethodState (DeleteSession sess) = ServerSessionAcidState sess
|
||||
instance AcidContext sess => Method (DeleteAllSessionsOfAuthId sess) where
|
||||
type MethodResult (DeleteAllSessionsOfAuthId sess) = ()
|
||||
type MethodState (DeleteAllSessionsOfAuthId sess) = ServerSessionAcidState sess
|
||||
instance AcidContext sess => Method (InsertSession sess) where
|
||||
type MethodResult (InsertSession sess) = ()
|
||||
type MethodState (InsertSession sess) = ServerSessionAcidState sess
|
||||
instance AcidContext sess => Method (ReplaceSession sess) where
|
||||
type MethodResult (ReplaceSession sess) = ()
|
||||
type MethodState (ReplaceSession sess) = ServerSessionAcidState sess
|
||||
|
||||
instance AcidContext sess => IsAcidic (ServerSessionAcidState sess) where
|
||||
acidEvents =
|
||||
[ QueryEvent $ \(GetSession sid) -> getSession sid
|
||||
, UpdateEvent $ \(DeleteSession sid) -> deleteSession sid
|
||||
, UpdateEvent $ \(DeleteAllSessionsOfAuthId authId) -> deleteAllSessionsOfAuthId authId
|
||||
, UpdateEvent $ \(InsertSession session) -> insertSession session
|
||||
, UpdateEvent $ \(ReplaceSession session) -> replaceSession session ]
|
||||
|
||||
@ -24,7 +24,7 @@ library
|
||||
, containers
|
||||
, path-pieces
|
||||
, persistent == 2.1.*
|
||||
, persistent-template == 2.1.*
|
||||
, tagged >= 0.8
|
||||
, text
|
||||
, time
|
||||
, transformers
|
||||
@ -37,14 +37,19 @@ library
|
||||
extensions:
|
||||
DeriveDataTypeable
|
||||
EmptyDataDecls
|
||||
FlexibleContexts
|
||||
FlexibleInstances
|
||||
GADTs
|
||||
GeneralizedNewtypeDeriving
|
||||
OverloadedStrings
|
||||
PatternGuards
|
||||
QuasiQuotes
|
||||
RecordWildCards
|
||||
ScopedTypeVariables
|
||||
StandaloneDeriving
|
||||
TemplateHaskell
|
||||
TypeFamilies
|
||||
UndecidableInstances
|
||||
ghc-options: -Wall
|
||||
|
||||
|
||||
|
||||
@ -22,9 +22,11 @@
|
||||
-- share [mkPersist sqlSettings, mkSave \"entityDefs\"]
|
||||
--
|
||||
-- -- On Application.hs
|
||||
-- import Data.Proxy (Proxy(..)) -- tagged package, or base from GHC 7.10 onwards
|
||||
-- import Web.ServerSession.Core (SessionMap)
|
||||
-- import Web.ServerSession.Backend.Persistent (serverSessionDefs)
|
||||
--
|
||||
-- mkMigrate \"migrateAll\" (serverSessionDefs ++ entityDefs)
|
||||
-- mkMigrate \"migrateAll\" (serverSessionDefs (Proxy :: Proxy SessionMap) ++ entityDefs)
|
||||
--
|
||||
-- makeFoundation =
|
||||
-- ...
|
||||
@ -32,6 +34,8 @@
|
||||
-- ...
|
||||
-- @
|
||||
--
|
||||
-- If you're not using @SessionMap@, just change @Proxy@ type above.
|
||||
--
|
||||
-- If you forget to setup the migration above, this session
|
||||
-- storage backend will fail at runtime as the required table
|
||||
-- will not exist.
|
||||
|
||||
@ -9,68 +9,284 @@ module Web.ServerSession.Backend.Persistent.Internal.Impl
|
||||
, toPersistentSession
|
||||
, fromPersistentSession
|
||||
, SqlStorage(..)
|
||||
, throwSS
|
||||
) where
|
||||
|
||||
import Control.Monad (void)
|
||||
import Control.Monad.IO.Class (liftIO)
|
||||
import Data.Proxy (Proxy(..))
|
||||
import Data.Time (UTCTime)
|
||||
import Data.Typeable (Typeable)
|
||||
import Database.Persist (PersistEntity(..))
|
||||
import Database.Persist.TH (mkPersist, mkSave, persistLowerCase, share, sqlSettings)
|
||||
import Web.PathPieces (PathPiece)
|
||||
import Web.ServerSession.Core
|
||||
|
||||
import qualified Control.Exception as E
|
||||
import qualified Data.Aeson as A
|
||||
import qualified Data.Map as M
|
||||
import qualified Data.Text as T
|
||||
import qualified Database.Persist as P
|
||||
import qualified Database.Persist.Sql as P
|
||||
|
||||
import Web.ServerSession.Backend.Persistent.Internal.Types
|
||||
|
||||
|
||||
share
|
||||
[mkPersist sqlSettings, mkSave "serverSessionDefs"]
|
||||
[persistLowerCase|
|
||||
PersistentSession json
|
||||
key SessionId -- Session ID, primary key.
|
||||
authId ByteStringJ Maybe -- Value of "_ID" session key.
|
||||
session SessionMapJ -- Rest of the session data.
|
||||
createdAt UTCTime -- When this session was created.
|
||||
accessedAt UTCTime -- When this session was last accessed.
|
||||
Primary key
|
||||
deriving Eq Ord Show Typeable
|
||||
|]
|
||||
-- We can't use the Template Haskell since we want to generalize
|
||||
-- some fields.
|
||||
--
|
||||
-- This is going to be a pain to upgrade when the next major
|
||||
-- persistent version comes :(.
|
||||
|
||||
-- | Entity corresponding to a 'Session'.
|
||||
--
|
||||
-- We're bending @persistent@ in ways it wasn't expected to. In
|
||||
-- particular, this entity is parametrized over the session type.
|
||||
data PersistentSession sess =
|
||||
PersistentSession
|
||||
{ persistentSessionKey :: !(SessionId sess) -- ^ Session ID, primary key.
|
||||
, persistentSessionAuthId :: !(Maybe ByteStringJ) -- ^ Value of "_ID" session key.
|
||||
, persistentSessionSession :: !(Decomposed sess) -- ^ Rest of the session data.
|
||||
, persistentSessionCreatedAt :: !UTCTime -- ^ When this session was created.
|
||||
, persistentSessionAccessedAt :: !UTCTime -- ^ When this session was last accessed.
|
||||
} deriving (Typeable)
|
||||
|
||||
deriving instance Eq (Decomposed sess) => Eq (PersistentSession sess)
|
||||
deriving instance Ord (Decomposed sess) => Ord (PersistentSession sess)
|
||||
deriving instance Show (Decomposed sess) => Show (PersistentSession sess)
|
||||
|
||||
|
||||
type PersistentSessionId sess = Key (PersistentSession sess)
|
||||
|
||||
instance forall sess. P.PersistFieldSql (Decomposed sess) => P.PersistEntity (PersistentSession sess) where
|
||||
type PersistEntityBackend (PersistentSession sess) = P.SqlBackend
|
||||
|
||||
data Unique (PersistentSession sess)
|
||||
|
||||
newtype Key (PersistentSession sess) =
|
||||
PersistentSessionKey' {unPersistentSessionKey :: SessionId sess}
|
||||
deriving ( Eq, Ord, Show, Read, PathPiece
|
||||
, P.PersistField, P.PersistFieldSql, A.ToJSON, A.FromJSON )
|
||||
|
||||
data EntityField (PersistentSession sess) typ =
|
||||
typ ~ PersistentSessionId sess => PersistentSessionId
|
||||
| typ ~ SessionId sess => PersistentSessionKey
|
||||
| typ ~ Maybe ByteStringJ => PersistentSessionAuthId
|
||||
| typ ~ Decomposed sess => PersistentSessionSession
|
||||
| typ ~ UTCTime => PersistentSessionCreatedAt
|
||||
| typ ~ UTCTime => PersistentSessionAccessedAt
|
||||
|
||||
keyToValues = (:[]) . P.toPersistValue . unPersistentSessionKey
|
||||
keyFromValues [x] | Right v <- P.fromPersistValue x = Right $ PersistentSessionKey' v
|
||||
keyFromValues xs = Left $ T.pack $ "PersistentSession/keyFromValues: " ++ show xs
|
||||
|
||||
entityDef _
|
||||
= P.EntityDef
|
||||
(P.HaskellName "PersistentSession")
|
||||
(P.DBName "persistent_session")
|
||||
(pfd PersistentSessionId)
|
||||
["json"]
|
||||
[ pfd PersistentSessionKey
|
||||
, pfd PersistentSessionAuthId
|
||||
, pfd PersistentSessionSession
|
||||
, pfd PersistentSessionCreatedAt
|
||||
, pfd PersistentSessionAccessedAt ]
|
||||
[]
|
||||
[]
|
||||
["Eq", "Ord", "Show", "Typeable"]
|
||||
M.empty
|
||||
False
|
||||
where
|
||||
pfd :: P.EntityField (PersistentSession sess) typ -> P.FieldDef
|
||||
pfd = P.persistFieldDef
|
||||
|
||||
toPersistFields (PersistentSession a b c d e) =
|
||||
[ P.SomePersistField a
|
||||
, P.SomePersistField b
|
||||
, P.SomePersistField c
|
||||
, P.SomePersistField d
|
||||
, P.SomePersistField e ]
|
||||
|
||||
fromPersistValues [a, b, c, d, e] =
|
||||
PersistentSession
|
||||
<$> err "key" (P.fromPersistValue a)
|
||||
<*> err "authId" (P.fromPersistValue b)
|
||||
<*> err "session" (P.fromPersistValue c)
|
||||
<*> err "createdAt" (P.fromPersistValue d)
|
||||
<*> err "accessedAt" (P.fromPersistValue e)
|
||||
where
|
||||
err :: T.Text -> Either T.Text a -> Either T.Text a
|
||||
err s (Left r) = Left $ T.concat ["PersistentSession/fromPersistValues/", s, ": ", r]
|
||||
err _ (Right v) = Right v
|
||||
fromPersistValues x = Left $ T.pack $ "PersistentSession/fromPersistValues: " ++ show x
|
||||
|
||||
persistUniqueToFieldNames _ = error "Degenerate case, should never happen"
|
||||
persistUniqueToValues _ = error "Degenerate case, should never happen"
|
||||
persistUniqueKeys _ = []
|
||||
|
||||
persistFieldDef PersistentSessionId
|
||||
= P.FieldDef
|
||||
(P.HaskellName "Id")
|
||||
(P.DBName "id")
|
||||
(P.FTTypeCon
|
||||
Nothing "PersistentSessionId")
|
||||
(P.SqlOther "Composite Reference")
|
||||
[]
|
||||
True
|
||||
(P.CompositeRef
|
||||
(P.CompositeDef
|
||||
[P.FieldDef
|
||||
(P.HaskellName "key")
|
||||
(P.DBName "key")
|
||||
(P.FTTypeCon Nothing "SessionId")
|
||||
(P.SqlOther "SqlType unset for key")
|
||||
[]
|
||||
True
|
||||
P.NoReference]
|
||||
[]))
|
||||
persistFieldDef PersistentSessionKey
|
||||
= P.FieldDef
|
||||
(P.HaskellName "key")
|
||||
(P.DBName "key")
|
||||
(P.FTTypeCon Nothing "SessionId sess")
|
||||
(P.sqlType (Proxy :: Proxy (SessionId sess)))
|
||||
[]
|
||||
True
|
||||
P.NoReference
|
||||
persistFieldDef PersistentSessionAuthId
|
||||
= P.FieldDef
|
||||
(P.HaskellName "authId")
|
||||
(P.DBName "auth_id")
|
||||
(P.FTTypeCon Nothing "ByteStringJ")
|
||||
(P.sqlType (Proxy :: Proxy ByteStringJ))
|
||||
["Maybe"]
|
||||
True
|
||||
P.NoReference
|
||||
persistFieldDef PersistentSessionSession
|
||||
= P.FieldDef
|
||||
(P.HaskellName "session")
|
||||
(P.DBName "session")
|
||||
(P.FTTypeCon Nothing "Decomposed sess")
|
||||
(P.sqlType (Proxy :: Proxy (Decomposed sess))) -- Important!
|
||||
[]
|
||||
True
|
||||
P.NoReference
|
||||
persistFieldDef PersistentSessionCreatedAt
|
||||
= P.FieldDef
|
||||
(P.HaskellName "createdAt")
|
||||
(P.DBName "created_at")
|
||||
(P.FTTypeCon Nothing "UTCTime")
|
||||
(P.sqlType (Proxy :: Proxy UTCTime))
|
||||
[]
|
||||
True
|
||||
P.NoReference
|
||||
persistFieldDef PersistentSessionAccessedAt
|
||||
= P.FieldDef
|
||||
(P.HaskellName "accessedAt")
|
||||
(P.DBName "accessed_at")
|
||||
(P.FTTypeCon Nothing "UTCTime")
|
||||
(P.sqlType (Proxy :: Proxy UTCTime))
|
||||
[]
|
||||
True
|
||||
P.NoReference
|
||||
|
||||
persistIdField = PersistentSessionId
|
||||
|
||||
fieldLens PersistentSessionId = lensPTH
|
||||
P.entityKey
|
||||
(\(P.Entity _ v) k -> P.Entity k v)
|
||||
fieldLens PersistentSessionKey = lensPTH
|
||||
(persistentSessionKey . P.entityVal)
|
||||
(\(P.Entity k v) x -> P.Entity k (v {persistentSessionKey = x}))
|
||||
fieldLens PersistentSessionAuthId = lensPTH
|
||||
(persistentSessionAuthId . P.entityVal)
|
||||
(\(P.Entity k v) x -> P.Entity k (v {persistentSessionAuthId = x}))
|
||||
fieldLens PersistentSessionSession = lensPTH
|
||||
(persistentSessionSession . P.entityVal)
|
||||
(\(P.Entity k v) x -> P.Entity k (v {persistentSessionSession = x}))
|
||||
fieldLens PersistentSessionCreatedAt = lensPTH
|
||||
(persistentSessionCreatedAt . P.entityVal)
|
||||
(\(P.Entity k v) x -> P.Entity k (v {persistentSessionCreatedAt = x}))
|
||||
fieldLens PersistentSessionAccessedAt = lensPTH
|
||||
(persistentSessionAccessedAt . P.entityVal)
|
||||
(\(P.Entity k v) x -> P.Entity k (v {persistentSessionAccessedAt = x}))
|
||||
|
||||
|
||||
-- | Copy-paste from @Database.Persist.TH@. Who needs lens anyway...
|
||||
lensPTH :: Functor f => (s -> a) -> (s -> b -> t) -> (a -> f b) -> s -> f t
|
||||
lensPTH sa sbt afb s = fmap (sbt s) (afb $ sa s)
|
||||
|
||||
|
||||
instance A.ToJSON (Decomposed sess) => A.ToJSON (PersistentSession sess) where
|
||||
toJSON (PersistentSession key authId session createdAt accessedAt) =
|
||||
A.object
|
||||
[ "key" A..= key
|
||||
, "authId" A..= authId
|
||||
, "session" A..= session
|
||||
, "createdAt" A..= createdAt
|
||||
, "accessedAt" A..= accessedAt ]
|
||||
|
||||
instance A.FromJSON (Decomposed sess) => A.FromJSON (PersistentSession sess) where
|
||||
parseJSON (A.Object obj) =
|
||||
PersistentSession
|
||||
<$> obj A..: "key"
|
||||
<*> obj A..: "authId"
|
||||
<*> obj A..: "session"
|
||||
<*> obj A..: "createdAt"
|
||||
<*> obj A..: "accessedAt"
|
||||
parseJSON _ = mempty
|
||||
|
||||
instance ( A.ToJSON (Decomposed sess)
|
||||
, P.PersistFieldSql (Decomposed sess)
|
||||
) => A.ToJSON (P.Entity (PersistentSession sess)) where
|
||||
toJSON = P.entityIdToJSON
|
||||
|
||||
instance ( A.FromJSON (Decomposed sess)
|
||||
, P.PersistFieldSql (Decomposed sess)
|
||||
) => A.FromJSON (P.Entity (PersistentSession sess)) where
|
||||
parseJSON = P.entityIdFromJSON
|
||||
|
||||
|
||||
-- | Entity definitions needed to generate the SQL schema for
|
||||
-- 'SqlStorage'. Example using 'SessionMap':
|
||||
--
|
||||
-- @
|
||||
-- serverSessionDefs (Proxy :: Proxy SessionMap)
|
||||
-- @
|
||||
serverSessionDefs :: forall sess. PersistEntity (PersistentSession sess) => Proxy sess -> [P.EntityDef]
|
||||
serverSessionDefs _ = [entityDef (Proxy :: Proxy (PersistentSession sess))]
|
||||
|
||||
|
||||
-- | Generate a key to the entity from the session ID.
|
||||
psKey :: SessionId -> Key PersistentSession
|
||||
psKey :: SessionId sess -> Key (PersistentSession sess)
|
||||
psKey = PersistentSessionKey'
|
||||
|
||||
|
||||
-- | Convert from 'Session' to 'PersistentSession'.
|
||||
toPersistentSession :: Session -> PersistentSession
|
||||
toPersistentSession :: Session sess -> PersistentSession sess
|
||||
toPersistentSession Session {..} =
|
||||
PersistentSession
|
||||
{ persistentSessionKey = sessionKey
|
||||
, persistentSessionAuthId = fmap B sessionAuthId
|
||||
, persistentSessionSession = M sessionData
|
||||
, persistentSessionSession = sessionData
|
||||
, persistentSessionCreatedAt = sessionCreatedAt
|
||||
, persistentSessionAccessedAt = sessionAccessedAt
|
||||
}
|
||||
|
||||
|
||||
-- | Convert from 'PersistentSession' to 'Session'.
|
||||
fromPersistentSession :: PersistentSession -> Session
|
||||
fromPersistentSession :: PersistentSession sess -> Session sess
|
||||
fromPersistentSession PersistentSession {..} =
|
||||
Session
|
||||
{ sessionKey = persistentSessionKey
|
||||
, sessionAuthId = fmap unB persistentSessionAuthId
|
||||
, sessionData = unM persistentSessionSession
|
||||
, sessionData = persistentSessionSession
|
||||
, sessionCreatedAt = persistentSessionCreatedAt
|
||||
, sessionAccessedAt = persistentSessionAccessedAt
|
||||
}
|
||||
|
||||
|
||||
-- | SQL session storage backend using @persistent@.
|
||||
newtype SqlStorage =
|
||||
newtype SqlStorage sess =
|
||||
SqlStorage
|
||||
{ connPool :: P.ConnectionPool
|
||||
-- ^ Pool of DB connections. You may use the same pool as
|
||||
@ -78,22 +294,38 @@ newtype SqlStorage =
|
||||
} deriving (Typeable)
|
||||
|
||||
|
||||
instance Storage SqlStorage where
|
||||
type TransactionM SqlStorage = P.SqlPersistT IO
|
||||
instance forall sess.
|
||||
( IsSessionData sess
|
||||
, P.PersistFieldSql (Decomposed sess)
|
||||
) => Storage (SqlStorage sess) where
|
||||
type SessionData (SqlStorage sess) = sess
|
||||
type TransactionM (SqlStorage sess) = P.SqlPersistT IO
|
||||
runTransactionM = flip P.runSqlPool . connPool
|
||||
getSession _ = fmap (fmap fromPersistentSession) . P.get . psKey
|
||||
deleteSession _ = P.delete . psKey
|
||||
deleteAllSessionsOfAuthId _ authId = P.deleteWhere [PersistentSessionAuthId P.==. Just (B authId)]
|
||||
deleteAllSessionsOfAuthId _ authId =
|
||||
P.deleteWhere [field P.==. Just (B authId)]
|
||||
where
|
||||
field :: EntityField (PersistentSession sess) (Maybe ByteStringJ)
|
||||
field = PersistentSessionAuthId
|
||||
insertSession s session = do
|
||||
mold <- getSession s (sessionKey session)
|
||||
maybe
|
||||
(void $ P.insert $ toPersistentSession session)
|
||||
(\old -> liftIO $ E.throwIO $ SessionAlreadyExists old session)
|
||||
(\old -> throwSS $ SessionAlreadyExists old session)
|
||||
mold
|
||||
replaceSession _ session = do
|
||||
let key = psKey $ sessionKey session
|
||||
mold <- P.get key
|
||||
maybe
|
||||
(liftIO $ E.throwIO $ SessionDoesNotExist session)
|
||||
(throwSS $ SessionDoesNotExist session)
|
||||
(\_old -> void $ P.replace key $ toPersistentSession session)
|
||||
mold
|
||||
|
||||
|
||||
-- | Specialization of 'E.throwIO' for 'SqlStorage'.
|
||||
throwSS
|
||||
:: Storage (SqlStorage sess)
|
||||
=> StorageException (SqlStorage sess)
|
||||
-> TransactionM (SqlStorage sess) a
|
||||
throwSS = liftIO . E.throwIO
|
||||
|
||||
@ -5,7 +5,11 @@
|
||||
-- Also exports orphan instances of @PersistField{,Sql} SessionId@.
|
||||
module Web.ServerSession.Backend.Persistent.Internal.Types
|
||||
( ByteStringJ(..)
|
||||
, SessionMapJ(..)
|
||||
-- * Orphan instances
|
||||
-- ** SessionId
|
||||
-- $orphanSessionId
|
||||
-- ** SessionMap
|
||||
-- $orphanSessionMap
|
||||
) where
|
||||
|
||||
import Control.Arrow (first)
|
||||
@ -29,12 +33,19 @@ import qualified Data.Text.Encoding as TE
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
-- | Does not do sanity checks (DB is trusted).
|
||||
instance PersistField SessionId where
|
||||
-- $orphanSessionId
|
||||
--
|
||||
-- @
|
||||
-- instance 'PersistField' ('SessionId' sess)
|
||||
-- instance 'PersistFieldSql' ('SessionId' sess)
|
||||
-- @
|
||||
--
|
||||
-- Does not do sanity checks (DB is trusted).
|
||||
instance PersistField (SessionId sess) where
|
||||
toPersistValue = toPersistValue . unS
|
||||
fromPersistValue = fmap S . fromPersistValue
|
||||
|
||||
instance PersistFieldSql SessionId where
|
||||
instance PersistFieldSql (SessionId sess) where
|
||||
sqlType p = sqlType (fmap unS p)
|
||||
|
||||
|
||||
@ -66,31 +77,41 @@ instance A.ToJSON ByteStringJ where
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
-- | Newtype of a 'SessionMap' that serializes using @cereal@ on
|
||||
-- $orphanSessionMap
|
||||
--
|
||||
-- @
|
||||
-- instance 'PersistField' 'SessionMap'
|
||||
-- instance 'PersistFieldSql' 'SessionMap'
|
||||
-- instance 'S.Serialize' 'SessionMap'
|
||||
-- instance 'A.FromJSON' 'SessionMap'
|
||||
-- instance 'A.ToJSON' 'SessionMap'
|
||||
-- @
|
||||
|
||||
-- 'PersistField' for 'SessionMap' serializes using @cereal@ on
|
||||
-- the database. We tried to use @aeson@ but @cereal@ is twice
|
||||
-- faster and uses half the memory for this use case.
|
||||
newtype SessionMapJ = M { unM :: SessionMap }
|
||||
deriving (Eq, Ord, Show, Read, Typeable)
|
||||
|
||||
instance PersistField SessionMapJ where
|
||||
--
|
||||
-- The JSON instance translates to objects using base64url for
|
||||
-- the values of 'ByteString' (cf. 'ByteStringJ').
|
||||
instance PersistField SessionMap where
|
||||
toPersistValue = toPersistValue . S.encode
|
||||
fromPersistValue = fromPersistValue >=> (either (Left . T.pack) Right . S.decode)
|
||||
|
||||
instance PersistFieldSql SessionMapJ where
|
||||
instance PersistFieldSql SessionMap where
|
||||
sqlType _ = SqlBlob
|
||||
|
||||
instance S.Serialize SessionMapJ where
|
||||
put = S.put . map (first TE.encodeUtf8) . M.toAscList . unM
|
||||
get = M . M.fromAscList . map (first TE.decodeUtf8) <$> S.get
|
||||
instance S.Serialize SessionMap where
|
||||
put = S.put . map (first TE.encodeUtf8) . M.toAscList . unSessionMap
|
||||
get = SessionMap . M.fromAscList . map (first TE.decodeUtf8) <$> S.get
|
||||
|
||||
instance A.FromJSON SessionMapJ where
|
||||
instance A.FromJSON SessionMap where
|
||||
parseJSON = fmap fixup . A.parseJSON
|
||||
where
|
||||
fixup :: M.Map Text ByteStringJ -> SessionMapJ
|
||||
fixup = M . fmap unB
|
||||
fixup :: M.Map Text ByteStringJ -> SessionMap
|
||||
fixup = SessionMap . fmap unB
|
||||
|
||||
instance A.ToJSON SessionMapJ where
|
||||
instance A.ToJSON SessionMap where
|
||||
toJSON = A.toJSON . mangle
|
||||
where
|
||||
mangle :: SessionMapJ -> M.Map Text ByteStringJ
|
||||
mangle = fmap B . unM
|
||||
mangle :: SessionMap -> M.Map Text ByteStringJ
|
||||
mangle = fmap B . unSessionMap
|
||||
|
||||
@ -3,17 +3,19 @@ module Main (main) where
|
||||
import Control.Monad (forM_)
|
||||
import Control.Monad.Logger (runStderrLoggingT, runNoLoggingT)
|
||||
import Data.Pool (destroyAllResources)
|
||||
import Data.Proxy (Proxy(..))
|
||||
import Database.Persist.Postgresql (createPostgresqlPool)
|
||||
import Database.Persist.Sqlite (createSqlitePool)
|
||||
import Test.Hspec
|
||||
import Web.ServerSession.Backend.Persistent
|
||||
import Web.ServerSession.Core (SessionMap)
|
||||
import Web.ServerSession.Core.StorageTests
|
||||
|
||||
import qualified Control.Exception as E
|
||||
import qualified Database.Persist.TH as P
|
||||
import qualified Database.Persist.Sql as P
|
||||
|
||||
P.mkMigrate "migrateAll" serverSessionDefs
|
||||
P.mkMigrate "migrateAll" (serverSessionDefs (Proxy :: Proxy SessionMap))
|
||||
|
||||
main :: IO ()
|
||||
main = hspec $
|
||||
|
||||
@ -21,6 +21,7 @@ library
|
||||
, containers
|
||||
, hedis == 0.6.*
|
||||
, path-pieces
|
||||
, tagged >= 0.8
|
||||
, text
|
||||
, time >= 1.5
|
||||
, transformers
|
||||
@ -31,8 +32,10 @@ library
|
||||
Web.ServerSession.Backend.Redis.Internal
|
||||
extensions:
|
||||
DeriveDataTypeable
|
||||
FlexibleContexts
|
||||
OverloadedStrings
|
||||
RecordWildCards
|
||||
ScopedTypeVariables
|
||||
TypeFamilies
|
||||
ghc-options: -Wall
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
module Web.ServerSession.Backend.Redis
|
||||
( RedisStorage(..)
|
||||
, RedisStorageException(..)
|
||||
, RedisSession(..)
|
||||
) where
|
||||
|
||||
import Web.ServerSession.Backend.Redis.Internal
|
||||
|
||||
@ -9,6 +9,7 @@ module Web.ServerSession.Backend.Redis.Internal
|
||||
, rSessionKey
|
||||
, rAuthKey
|
||||
|
||||
, RedisSession(..)
|
||||
, parseSession
|
||||
, printSession
|
||||
, parseUTCTime
|
||||
@ -22,6 +23,7 @@ module Web.ServerSession.Backend.Redis.Internal
|
||||
, deleteAllSessionsOfAuthIdImpl
|
||||
, insertSessionImpl
|
||||
, replaceSessionImpl
|
||||
, throwRS
|
||||
) where
|
||||
|
||||
import Control.Applicative ((<$))
|
||||
@ -31,6 +33,7 @@ import Control.Monad.IO.Class (liftIO)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.List (partition)
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Proxy (Proxy(..))
|
||||
import Data.Typeable (Typeable)
|
||||
import Web.PathPieces (toPathPiece)
|
||||
import Web.ServerSession.Core
|
||||
@ -49,7 +52,7 @@ import qualified Data.Time.Format as TI
|
||||
|
||||
|
||||
-- | Session storage backend using Redis via the @hedis@ package.
|
||||
newtype RedisStorage =
|
||||
newtype RedisStorage sess =
|
||||
RedisStorage
|
||||
{ connPool :: R.Connection
|
||||
-- ^ Connection pool to the Redis server.
|
||||
@ -58,8 +61,9 @@ newtype RedisStorage =
|
||||
|
||||
-- | We do not provide any ACID guarantees for different actions
|
||||
-- running inside the same @TransactionM RedisStorage@.
|
||||
instance Storage RedisStorage where
|
||||
type TransactionM RedisStorage = R.Redis
|
||||
instance RedisSession sess => Storage (RedisStorage sess) where
|
||||
type SessionData (RedisStorage sess) = sess
|
||||
type TransactionM (RedisStorage sess) = R.Redis
|
||||
runTransactionM = R.runRedis . connPool
|
||||
getSession _ = getSessionImpl
|
||||
deleteSession _ = deleteSessionImpl
|
||||
@ -102,7 +106,7 @@ unwrap act = act >>= either (liftIO . E.throwIO . ExpectedRight) return
|
||||
|
||||
|
||||
-- | Redis key for the given session ID.
|
||||
rSessionKey :: SessionId -> ByteString
|
||||
rSessionKey :: SessionId sess -> ByteString
|
||||
rSessionKey = B.append "ssr:session:" . TE.encodeUtf8 . toPathPiece
|
||||
|
||||
|
||||
@ -114,8 +118,39 @@ rAuthKey = B.append "ssr:authid:"
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
-- | Class for data types that can be used as session data for
|
||||
-- the Redis backend.
|
||||
--
|
||||
-- It should hold that
|
||||
--
|
||||
-- @
|
||||
-- fromHash p . perm . toHash p === id
|
||||
-- @
|
||||
--
|
||||
-- for all list permutations @perm :: [a] -> [a]@,
|
||||
-- where @p :: Proxy sess@.
|
||||
class IsSessionData sess => RedisSession sess where
|
||||
-- | Transform a decomposed session into a Redis hash. Keys
|
||||
-- will be prepended with @\"data:\"@ before being stored.
|
||||
toHash :: Proxy sess -> Decomposed sess -> [(ByteString, ByteString)]
|
||||
|
||||
-- | Parse back a Redis hash into session data.
|
||||
fromHash :: Proxy sess -> [(ByteString, ByteString)] -> Decomposed sess
|
||||
|
||||
|
||||
-- | Assumes that keys are UTF-8 encoded when parsing (which is
|
||||
-- true if keys are always generated via @toHash@).
|
||||
instance RedisSession SessionMap where
|
||||
toHash _ = map (first TE.encodeUtf8) . M.toList . unSessionMap
|
||||
fromHash _ = SessionMap . M.fromList . map (first TE.decodeUtf8)
|
||||
|
||||
|
||||
-- | Parse a 'Session' from a Redis hash.
|
||||
parseSession :: SessionId -> [(ByteString, ByteString)] -> Maybe Session
|
||||
parseSession
|
||||
:: forall sess. RedisSession sess
|
||||
=> SessionId sess
|
||||
-> [(ByteString, ByteString)]
|
||||
-> Maybe (Session sess)
|
||||
parseSession _ [] = Nothing
|
||||
parseSession sid bss =
|
||||
let (externalList, internalList) = partition (B8.isPrefixOf "data:" . fst) bss
|
||||
@ -124,25 +159,26 @@ parseSession sid bss =
|
||||
accessedAt = parseUTCTime $ lookup' "internal:accessedAt"
|
||||
lookup' k = fromMaybe (error err) $ lookup k internalList
|
||||
where err = "serversession-backend-redis/parseSession: missing key " ++ show k
|
||||
sessionMap = M.fromList $ map (first $ TE.decodeUtf8 . removePrefix) externalList
|
||||
data_ = fromHash p $ map (first removePrefix) externalList
|
||||
where removePrefix bs = let ("data:", key) = B8.splitAt 5 bs in key
|
||||
p = Proxy :: Proxy sess
|
||||
in Just Session
|
||||
{ sessionKey = sid
|
||||
, sessionAuthId = authId
|
||||
, sessionData = sessionMap
|
||||
, sessionData = data_
|
||||
, sessionCreatedAt = createdAt
|
||||
, sessionAccessedAt = accessedAt
|
||||
}
|
||||
|
||||
|
||||
-- | Convert a 'Session' into a Redis hash.
|
||||
printSession :: Session -> [(ByteString, ByteString)]
|
||||
printSession :: forall sess. RedisSession sess => Session sess -> [(ByteString, ByteString)]
|
||||
printSession Session {..} =
|
||||
maybe id ((:) . (,) "internal:authId") sessionAuthId $
|
||||
(:) ("internal:createdAt", printUTCTime sessionCreatedAt) $
|
||||
(:) ("internal:accessedAt", printUTCTime sessionAccessedAt) $
|
||||
map (first $ B8.append "data:" . TE.encodeUtf8) $
|
||||
M.toList sessionData
|
||||
map (first $ B8.append "data:") $
|
||||
toHash (Proxy :: Proxy sess) sessionData
|
||||
|
||||
|
||||
-- | Parse 'UTCTime' from a 'ByteString' stored on Redis. Uses
|
||||
@ -177,12 +213,12 @@ batched f xs =
|
||||
|
||||
|
||||
-- | Get the session for the given session ID.
|
||||
getSessionImpl :: SessionId -> R.Redis (Maybe Session)
|
||||
getSessionImpl :: RedisSession sess => SessionId sess -> R.Redis (Maybe (Session sess))
|
||||
getSessionImpl sid = parseSession sid <$> unwrap (R.hgetall $ rSessionKey sid)
|
||||
|
||||
|
||||
-- | Delete the session with given session ID.
|
||||
deleteSessionImpl :: SessionId -> R.Redis ()
|
||||
deleteSessionImpl :: RedisSession sess => SessionId sess -> R.Redis ()
|
||||
deleteSessionImpl sid = do
|
||||
msession <- getSessionImpl sid
|
||||
case msession of
|
||||
@ -196,18 +232,22 @@ deleteSessionImpl sid = do
|
||||
|
||||
-- | Remove the given 'SessionId' from the set of sessions of the
|
||||
-- given 'AuthId'. Does not do anything if @Nothing@.
|
||||
removeSessionFromAuthId :: R.RedisCtx m f => SessionId -> Maybe AuthId -> m ()
|
||||
removeSessionFromAuthId :: R.RedisCtx m f => SessionId sess -> Maybe AuthId -> m ()
|
||||
removeSessionFromAuthId = fooSessionBarAuthId R.srem
|
||||
|
||||
-- | Insert the given 'SessionId' into the set of sessions of the
|
||||
-- given 'AuthId'. Does not do anything if @Nothing@.
|
||||
insertSessionForAuthId :: R.RedisCtx m f => SessionId -> Maybe AuthId -> m ()
|
||||
insertSessionForAuthId :: R.RedisCtx m f => SessionId sess -> Maybe AuthId -> m ()
|
||||
insertSessionForAuthId = fooSessionBarAuthId R.sadd
|
||||
|
||||
|
||||
-- | (Internal) Helper for 'removeSessionFromAuthId' and 'insertSessionForAuthId'
|
||||
fooSessionBarAuthId
|
||||
:: R.RedisCtx m f => (ByteString -> [ByteString] -> m (f Integer)) -> SessionId -> Maybe AuthId -> m ()
|
||||
:: R.RedisCtx m f
|
||||
=> (ByteString -> [ByteString] -> m (f Integer))
|
||||
-> SessionId sess
|
||||
-> Maybe AuthId
|
||||
-> m ()
|
||||
fooSessionBarAuthId _ _ Nothing = return ()
|
||||
fooSessionBarAuthId fun sid (Just authId) = void $ fun (rAuthKey authId) [rSessionKey sid]
|
||||
|
||||
@ -220,13 +260,13 @@ deleteAllSessionsOfAuthIdImpl authId = do
|
||||
|
||||
|
||||
-- | Insert a new session.
|
||||
insertSessionImpl :: Session -> R.Redis ()
|
||||
insertSessionImpl :: RedisSession sess => Session sess -> R.Redis ()
|
||||
insertSessionImpl session = do
|
||||
-- Check that no old session exists.
|
||||
let sid = sessionKey session
|
||||
moldSession <- getSessionImpl sid
|
||||
case moldSession of
|
||||
Just oldSession -> liftIO $ E.throwIO $ SessionAlreadyExists oldSession session
|
||||
Just oldSession -> throwRS $ SessionAlreadyExists oldSession session
|
||||
Nothing -> do
|
||||
transaction $ do
|
||||
let sk = rSessionKey sid
|
||||
@ -237,13 +277,13 @@ insertSessionImpl session = do
|
||||
|
||||
|
||||
-- | Replace the contents of a session.
|
||||
replaceSessionImpl :: Session -> R.Redis ()
|
||||
replaceSessionImpl :: RedisSession sess => Session sess -> R.Redis ()
|
||||
replaceSessionImpl session = do
|
||||
-- Check that the old session exists.
|
||||
let sid = sessionKey session
|
||||
moldSession <- getSessionImpl sid
|
||||
case moldSession of
|
||||
Nothing -> liftIO $ E.throwIO $ SessionDoesNotExist session
|
||||
Nothing -> throwRS $ SessionDoesNotExist session
|
||||
Just oldSession -> do
|
||||
transaction $ do
|
||||
-- Delete the old session and set the new one.
|
||||
@ -259,3 +299,11 @@ replaceSessionImpl session = do
|
||||
insertSessionForAuthId sid newAuthId
|
||||
|
||||
return (() <$ r)
|
||||
|
||||
|
||||
-- | Specialization of 'E.throwIO' for 'RedisStorage'.
|
||||
throwRS
|
||||
:: Storage (RedisStorage sess)
|
||||
=> StorageException (RedisStorage sess)
|
||||
-> R.Redis a
|
||||
throwRS = liftIO . E.throwIO
|
||||
|
||||
@ -32,7 +32,10 @@ library
|
||||
Web.ServerSession.Frontend.Snap
|
||||
Web.ServerSession.Frontend.Snap.Internal
|
||||
extensions:
|
||||
FlexibleContexts
|
||||
OverloadedStrings
|
||||
TypeFamilies
|
||||
UndecidableInstances
|
||||
ghc-options: -Wall
|
||||
|
||||
source-repository head
|
||||
|
||||
@ -3,6 +3,7 @@ module Web.ServerSession.Frontend.Snap
|
||||
( -- * Using server-side sessions
|
||||
initServerSessionManager
|
||||
, simpleServerSessionManager
|
||||
, SnapSession(..)
|
||||
-- * Invalidating session IDs
|
||||
, forceInvalidate
|
||||
, ForceInvalidate(..)
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
module Web.ServerSession.Frontend.Snap.Internal
|
||||
( initServerSessionManager
|
||||
, simpleServerSessionManager
|
||||
, SnapSession(..)
|
||||
, ServerSessionManager(..)
|
||||
, currentSessionMap
|
||||
, modifyCurrentSession
|
||||
@ -33,7 +34,10 @@ import qualified Snap.Snaplet.Session.SessionManager as S
|
||||
|
||||
|
||||
-- | Create a new 'ServerSessionManager' using the given 'State'.
|
||||
initServerSessionManager :: Storage s => IO (State s) -> S.SnapletInit b S.SessionManager
|
||||
initServerSessionManager
|
||||
:: (Storage sto, SnapSession (SessionData sto))
|
||||
=> IO (State sto)
|
||||
-> S.SnapletInit b S.SessionManager
|
||||
initServerSessionManager mkState =
|
||||
S.makeSnaplet "ServerSession"
|
||||
"Snaplet providing sessions via server-side storage."
|
||||
@ -51,17 +55,67 @@ initServerSessionManager mkState =
|
||||
|
||||
-- | Simplified version of 'initServerSessionManager', sufficient
|
||||
-- for most needs.
|
||||
simpleServerSessionManager :: Storage s => IO s -> (State s -> State s) -> S.SnapletInit b S.SessionManager
|
||||
simpleServerSessionManager
|
||||
:: (Storage sto, SessionData sto ~ SessionMap)
|
||||
=> IO sto
|
||||
-> (State sto -> State sto)
|
||||
-> S.SnapletInit b S.SessionManager
|
||||
simpleServerSessionManager mkStorage opts =
|
||||
initServerSessionManager (fmap opts . createState =<< mkStorage)
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
-- | Class for data types that implement the operations Snap
|
||||
-- expects sessions to support.
|
||||
class IsSessionData sess => SnapSession sess where
|
||||
ssInsert :: Text -> Text -> sess -> sess
|
||||
ssLookup :: Text -> sess -> Maybe Text
|
||||
ssDelete :: Text -> sess -> sess
|
||||
ssToList :: sess -> [(Text, Text)]
|
||||
|
||||
ssInsertCsrf :: Text -> sess -> sess
|
||||
ssLookupCsrf :: sess -> Maybe Text
|
||||
|
||||
ssForceInvalidate :: ForceInvalidate -> sess -> sess
|
||||
|
||||
|
||||
-- | Uses 'csrfKey'.
|
||||
instance SnapSession SessionMap where
|
||||
ssInsert key val = onSM (M.insert key (TE.encodeUtf8 val))
|
||||
ssLookup key = fmap TE.decodeUtf8 . M.lookup key . unSessionMap
|
||||
ssDelete key = onSM (M.delete key)
|
||||
ssToList =
|
||||
-- Remove the CSRF key from the list as the current
|
||||
-- clientsession backend doesn't return it.
|
||||
fmap (second TE.decodeUtf8) .
|
||||
M.toList .
|
||||
M.delete csrfKey .
|
||||
unSessionMap
|
||||
|
||||
ssInsertCsrf = ssInsert csrfKey
|
||||
ssLookupCsrf = ssLookup csrfKey
|
||||
|
||||
ssForceInvalidate force = onSM (M.insert forceInvalidateKey (B8.pack $ show force))
|
||||
|
||||
|
||||
-- | Apply a function to a 'SessionMap'.
|
||||
onSM
|
||||
:: (M.Map Text ByteString -> M.Map Text ByteString)
|
||||
-> (SessionMap -> SessionMap)
|
||||
onSM f = SessionMap . f . unSessionMap
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
-- | A 'S.ISessionManager' using server-side sessions.
|
||||
data ServerSessionManager s =
|
||||
data ServerSessionManager sto =
|
||||
ServerSessionManager
|
||||
{ currentSession :: Maybe (SessionMap, SaveSessionToken)
|
||||
{ currentSession :: Maybe (SessionData sto, SaveSessionToken sto)
|
||||
-- ^ Field used for per-request caching of the session.
|
||||
, state :: State s
|
||||
, state :: State sto
|
||||
-- ^ The core @serversession@ state.
|
||||
, cookieName :: ByteString
|
||||
-- ^ Cache of the cookie name as bytestring.
|
||||
@ -70,26 +124,32 @@ data ServerSessionManager s =
|
||||
} deriving (Typeable)
|
||||
|
||||
|
||||
instance Storage s => S.ISessionManager (ServerSessionManager s) where
|
||||
instance ( Storage sto
|
||||
, SnapSession (SessionData sto)
|
||||
) => S.ISessionManager (ServerSessionManager sto) where
|
||||
load ssm@ServerSessionManager { currentSession = Just _ } =
|
||||
-- Don't do anything if already loaded. Yeah, I know this is
|
||||
-- strange, go figure.
|
||||
return ssm
|
||||
load ssm = do
|
||||
-- Get session ID from cookie.
|
||||
mcookie <- S.getCookie (cookieName ssm)
|
||||
-- Load session from storage backend.
|
||||
(sessionMap, saveSessionToken) <-
|
||||
(data1, saveSessionToken) <-
|
||||
liftIO $ loadSession (state ssm) (S.cookieValue <$> mcookie)
|
||||
-- Add CSRF token if needed.
|
||||
sessionMap' <-
|
||||
data2 <-
|
||||
maybe
|
||||
(flip (M.insert csrfKey) sessionMap <$> N.nonce128url (nonceGen ssm))
|
||||
(const $ return sessionMap)
|
||||
(M.lookup csrfKey sessionMap)
|
||||
(flip ssInsertCsrf data1 <$> N.nonce128urlT (nonceGen ssm))
|
||||
(const $ return data1)
|
||||
(ssLookupCsrf data1)
|
||||
-- Good to go!
|
||||
return ssm { currentSession = Just (sessionMap', saveSessionToken) }
|
||||
return ssm { currentSession = Just (data2, saveSessionToken) }
|
||||
|
||||
commit ssm = do
|
||||
-- Save session data to storage backend and set the cookie.
|
||||
let Just (sessionMap, saveSessionToken) = currentSession ssm
|
||||
msession <- liftIO $ saveSession (state ssm) saveSessionToken sessionMap
|
||||
let Just (data_, saveSessionToken) = currentSession ssm
|
||||
msession <- liftIO $ saveSession (state ssm) saveSessionToken data_
|
||||
S.modifyResponse $ S.addResponseCookie $
|
||||
maybe
|
||||
(deleteCookie (state ssm) (cookieName ssm))
|
||||
@ -100,60 +160,62 @@ instance Storage s => S.ISessionManager (ServerSessionManager s) where
|
||||
-- Reset has no defined semantics. We invalidate the session
|
||||
-- and clear its variables, which seems to be what the
|
||||
-- current clientsession backend from the snap package does.
|
||||
csrfToken <- N.nonce128url (nonceGen ssm)
|
||||
let newSession = M.fromList [ (forceInvalidateKey, B8.pack $ show CurrentSessionId)
|
||||
, (csrfKey, csrfToken) ]
|
||||
csrfToken <- N.nonce128urlT (nonceGen ssm)
|
||||
let newSession =
|
||||
ssInsertCsrf csrfToken $
|
||||
ssForceInvalidate CurrentSessionId $
|
||||
emptySession
|
||||
return $ modifyCurrentSession (const newSession) ssm
|
||||
|
||||
touch =
|
||||
-- We always touch the session (if commit is called).
|
||||
id
|
||||
|
||||
insert key value = modifyCurrentSession (M.insert key (TE.encodeUtf8 value))
|
||||
insert key value = modifyCurrentSession (ssInsert key value)
|
||||
|
||||
lookup key =
|
||||
-- Decoding will always succeed if the session is used only
|
||||
-- from snap.
|
||||
fmap TE.decodeUtf8 . M.lookup key . currentSessionMap "lookup"
|
||||
ssLookup key . currentSessionMap "lookup"
|
||||
|
||||
delete key = modifyCurrentSession (M.delete key)
|
||||
delete key = modifyCurrentSession (ssDelete key)
|
||||
|
||||
csrf =
|
||||
-- Guaranteed to succeed since both load and reset add a
|
||||
-- csrfKey to the session map.
|
||||
fromMaybe (error "serversession-frontend-snap/csrf: never here") .
|
||||
S.lookup csrfKey
|
||||
ssLookupCsrf . currentSessionMap "csrf"
|
||||
|
||||
toList =
|
||||
-- Remove the CSRF key from the list as the current
|
||||
-- clientsession backend doesn't return it.
|
||||
fmap (second TE.decodeUtf8) .
|
||||
M.toList .
|
||||
M.delete csrfKey .
|
||||
currentSessionMap "toList"
|
||||
toList = ssToList . currentSessionMap "toList"
|
||||
|
||||
|
||||
-- | Get the current 'SessionMap' from 'currentSession' and
|
||||
-- | Get the current 'SessionData' from 'currentSession' and
|
||||
-- unwrap its @Just@. If it's @Nothing@, @error@ is called. We
|
||||
-- expect 'load' to be called before any other 'ISessionManager'
|
||||
-- method.
|
||||
currentSessionMap :: String -> ServerSessionManager s -> SessionMap
|
||||
currentSessionMap :: String -> ServerSessionManager sto -> SessionData sto
|
||||
currentSessionMap fn ssm = maybe (error err) fst (currentSession ssm)
|
||||
where err = "serversession-frontend-snap/" ++ fn ++
|
||||
": currentSession is Nothing, did you call 'load'?"
|
||||
|
||||
|
||||
-- | Modify the current session in any way.
|
||||
modifyCurrentSession :: (SessionMap -> SessionMap) -> ServerSessionManager s -> ServerSessionManager s
|
||||
modifyCurrentSession
|
||||
:: (SessionData sto -> SessionData sto)
|
||||
-> ServerSessionManager sto
|
||||
-> ServerSessionManager sto
|
||||
modifyCurrentSession f ssm = ssm { currentSession = fmap (first f) (currentSession ssm) }
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
-- | Create a cookie for the given session.
|
||||
--
|
||||
-- The cookie expiration is set via 'nextExpires'. Note that
|
||||
-- this is just an optimization, as the expiration is checked on
|
||||
-- the server-side as well.
|
||||
createCookie :: State s -> ByteString -> Session -> S.Cookie
|
||||
createCookie :: State sto -> ByteString -> Session sess -> S.Cookie
|
||||
createCookie st cookieNameBS session =
|
||||
-- Generate a cookie with the final session ID.
|
||||
S.Cookie
|
||||
@ -176,7 +238,7 @@ createCookie st cookieNameBS session =
|
||||
-- * If the user had a session cookie that was invalidated,
|
||||
-- this will remove the invalid cookie from the client.
|
||||
-- the server-side as well.
|
||||
deleteCookie :: State s -> ByteString -> S.Cookie
|
||||
deleteCookie :: State sto -> ByteString -> S.Cookie
|
||||
deleteCookie st cookieNameBS =
|
||||
S.Cookie
|
||||
{ S.cookieName = cookieNameBS
|
||||
|
||||
@ -34,7 +34,9 @@ library
|
||||
Web.ServerSession.Frontend.Wai
|
||||
Web.ServerSession.Frontend.Wai.Internal
|
||||
extensions:
|
||||
FlexibleContexts
|
||||
OverloadedStrings
|
||||
TypeFamilies
|
||||
ghc-options: -Wall
|
||||
|
||||
source-repository head
|
||||
|
||||
@ -17,6 +17,7 @@ module Web.ServerSession.Frontend.Wai
|
||||
-- * Flexible interface
|
||||
, sessionStore
|
||||
, createCookieTemplate
|
||||
, KeyValue(..)
|
||||
-- * State configuration
|
||||
, setCookieName
|
||||
, setAuthKey
|
||||
|
||||
@ -4,6 +4,7 @@ module Web.ServerSession.Frontend.Wai.Internal
|
||||
( withServerSession
|
||||
, sessionStore
|
||||
, mkSession
|
||||
, KeyValue(..)
|
||||
, createCookieTemplate
|
||||
, calculateMaxAge
|
||||
, forceInvalidate
|
||||
@ -34,10 +35,10 @@ import qualified Web.Cookie as C
|
||||
-- that uses 'WS.withSession', 'createState', 'sessionStore',
|
||||
-- 'getCookieName' and 'createCookieTemplate'.
|
||||
withServerSession
|
||||
:: (MonadIO m, MonadIO n, Storage s)
|
||||
:: (MonadIO m, MonadIO n, Storage sto, SessionData sto ~ SessionMap)
|
||||
=> V.Key (WS.Session m Text ByteString) -- ^ 'V.Vault' key to use when passing the session through.
|
||||
-> (State s -> State s) -- ^ Set any options on the @serversession@ state.
|
||||
-> s -- ^ Storage backend.
|
||||
-> (State sto -> State sto) -- ^ Set any options on the @serversession@ state.
|
||||
-> sto -- ^ Storage backend.
|
||||
-> n W.Middleware
|
||||
withServerSession key opts storage = liftIO $ do
|
||||
st <- opts <$> createState storage
|
||||
@ -56,32 +57,55 @@ withServerSession key opts storage = liftIO $ do
|
||||
-- return an empty @ByteString@ when the empty session was not
|
||||
-- saved.
|
||||
sessionStore
|
||||
:: (MonadIO m, Storage s)
|
||||
=> State s -- ^ @serversession@ state, incl. storage backend.
|
||||
-> WS.SessionStore m Text ByteString -- ^ @wai-session@ session store.
|
||||
:: (MonadIO m, Storage sto, KeyValue (SessionData sto))
|
||||
=> State sto -- ^ @serversession@ state, incl. storage backend.
|
||||
-> WS.SessionStore m (Key (SessionData sto)) (Value (SessionData sto))
|
||||
-- ^ @wai-session@ session store.
|
||||
sessionStore state =
|
||||
\mcookieVal -> do
|
||||
(sessionMap, saveSessionToken) <- loadSession state mcookieVal
|
||||
sessionRef <- I.newIORef sessionMap
|
||||
(data1, saveSessionToken) <- loadSession state mcookieVal
|
||||
sessionRef <- I.newIORef data1
|
||||
let save = do
|
||||
sessionMap' <- I.atomicModifyIORef' sessionRef $ \a -> (a, a)
|
||||
msession <- saveSession state saveSessionToken sessionMap'
|
||||
data2 <- I.atomicModifyIORef' sessionRef $ \a -> (a, a)
|
||||
msession <- saveSession state saveSessionToken data2
|
||||
return $ maybe "" (TE.encodeUtf8 . toPathPiece . sessionKey) msession
|
||||
return (mkSession sessionRef, save)
|
||||
|
||||
|
||||
-- | Build a 'WS.Session' from an 'I.IORef' containing the
|
||||
-- session data.
|
||||
mkSession :: MonadIO m => I.IORef SessionMap -> WS.Session m Text ByteString
|
||||
mkSession :: (MonadIO m, KeyValue sess) => I.IORef sess -> WS.Session m (Key sess) (Value sess)
|
||||
mkSession sessionRef =
|
||||
-- We need to use atomicModifyIORef instead of readIORef
|
||||
-- because latter may be reordered (cf. "Memory Model" on
|
||||
-- Data.IORef's documentation).
|
||||
( \k -> M.lookup k <$> liftIO (I.atomicModifyIORef' sessionRef $ \a -> (a, a))
|
||||
, \k v -> liftIO (I.atomicModifyIORef' sessionRef $ flip (,) () . M.insert k v)
|
||||
( \k -> kvLookup k <$> liftIO (I.atomicModifyIORef' sessionRef $ \a -> (a, a))
|
||||
, \k v -> liftIO (I.atomicModifyIORef' sessionRef $ flip (,) () . kvInsert k v)
|
||||
)
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
-- | Class for session data types that can be used as key-value
|
||||
-- stores.
|
||||
class IsSessionData sess => KeyValue sess where
|
||||
type Key sess :: *
|
||||
type Value sess :: *
|
||||
kvLookup :: Key sess -> sess -> Maybe (Value sess)
|
||||
kvInsert :: Key sess -> Value sess -> sess -> sess
|
||||
|
||||
|
||||
instance KeyValue SessionMap where
|
||||
type Key SessionMap = Text
|
||||
type Value SessionMap = ByteString
|
||||
kvLookup k = M.lookup k . unSessionMap
|
||||
kvInsert k v (SessionMap m) = SessionMap (M.insert k v m)
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
-- | Create a cookie template given a state.
|
||||
--
|
||||
-- Since we don't have access to the 'Session', we can't fill the
|
||||
@ -94,7 +118,7 @@ mkSession sessionRef =
|
||||
-- Instead, we fill only the @Max-age@ field. It works fine for
|
||||
-- modern browsers, but many don't support it and will treat the
|
||||
-- cookie as non-persistent (notably IE 6, 7 and 8).
|
||||
createCookieTemplate :: State s -> C.SetCookie
|
||||
createCookieTemplate :: State sto -> C.SetCookie
|
||||
createCookieTemplate state =
|
||||
-- Generate a cookie with the final session ID.
|
||||
def
|
||||
@ -115,7 +139,7 @@ createCookieTemplate state =
|
||||
-- * If no timeout is defined, the result is 10 years.
|
||||
--
|
||||
-- * Otherwise, the max age is set as the maximum timeout.
|
||||
calculateMaxAge :: State s -> Maybe TI.DiffTime
|
||||
calculateMaxAge :: State sto -> Maybe TI.DiffTime
|
||||
calculateMaxAge st = do
|
||||
guard (persistentCookies st)
|
||||
return $ maybe (60*60*24*3652) realToFrac
|
||||
|
||||
@ -18,6 +18,7 @@ library
|
||||
build-depends:
|
||||
base == 4.*
|
||||
, bytestring
|
||||
, containers
|
||||
, cookie >= 0.4
|
||||
, data-default
|
||||
, path-pieces
|
||||
@ -32,7 +33,9 @@ library
|
||||
Web.ServerSession.Frontend.Yesod
|
||||
Web.ServerSession.Frontend.Yesod.Internal
|
||||
extensions:
|
||||
FlexibleContexts
|
||||
OverloadedStrings
|
||||
TypeFamilies
|
||||
ghc-options: -Wall
|
||||
|
||||
source-repository head
|
||||
|
||||
@ -1,8 +1,19 @@
|
||||
-- | Yesod server-side session support.
|
||||
--
|
||||
-- This package implements an Yesod @SessionBackend@, so it's a
|
||||
-- drop-in replacement for the default @clientsession@.
|
||||
--
|
||||
-- Unfortunately, Yesod currently provides no way of accessing
|
||||
-- the session other than via its own functions. If you want to
|
||||
-- use a custom data type as your session data (instead of the
|
||||
-- default @SessionMap@), it will have to implement
|
||||
-- 'IsSessionMap' and you'll have to continue using Yesod's
|
||||
-- session interface.
|
||||
module Web.ServerSession.Frontend.Yesod
|
||||
( -- * Using server-side sessions
|
||||
simpleBackend
|
||||
, backend
|
||||
, IsSessionMap(..)
|
||||
-- * Invalidating session IDs
|
||||
, forceInvalidate
|
||||
, ForceInvalidate(..)
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
module Web.ServerSession.Frontend.Yesod.Internal
|
||||
( simpleBackend
|
||||
, backend
|
||||
, IsSessionMap(..)
|
||||
, createCookie
|
||||
, findSessionId
|
||||
, forceInvalidate
|
||||
@ -12,6 +13,7 @@ import Control.Monad (guard)
|
||||
import Control.Monad.IO.Class (MonadIO)
|
||||
import Data.ByteString (ByteString)
|
||||
import Data.Default (def)
|
||||
import Data.Text (Text)
|
||||
import Web.PathPieces (toPathPiece)
|
||||
import Web.ServerSession.Core
|
||||
import Yesod.Core (MonadHandler)
|
||||
@ -19,6 +21,7 @@ import Yesod.Core.Handler (setSessionBS)
|
||||
import Yesod.Core.Types (Header(AddCookie), SessionBackend(..))
|
||||
|
||||
import qualified Data.ByteString.Char8 as B8
|
||||
import qualified Data.Map as M
|
||||
import qualified Data.Text.Encoding as TE
|
||||
import qualified Data.Time as TI
|
||||
import qualified Network.Wai as W
|
||||
@ -53,42 +56,70 @@ import qualified Web.Cookie as C
|
||||
-- . setAbsoluteTimeout (Just $ 60*60*24)
|
||||
-- . setSecureCookies True
|
||||
-- @
|
||||
--
|
||||
-- This is a simple version of 'backend' specialized for using
|
||||
-- 'SessionMap' as 'SessionData'. If you want to use a different
|
||||
-- session data type, please use 'backend' directly (tip: take a
|
||||
-- peek at this function's source).
|
||||
simpleBackend
|
||||
:: (MonadIO m, Storage s)
|
||||
=> (State s -> State s) -- ^ Set any options on the @serversession@ state.
|
||||
-> s -- ^ Storage backend.
|
||||
:: (MonadIO m, Storage sto, SessionData sto ~ SessionMap)
|
||||
=> (State sto -> State sto) -- ^ Set any options on the @serversession@ state.
|
||||
-> sto -- ^ Storage backend.
|
||||
-> m (Maybe SessionBackend) -- ^ Yesod session backend (always @Just@).
|
||||
simpleBackend opts s =
|
||||
return . Just . backend . opts =<< createState s
|
||||
|
||||
|
||||
-- | Construct the server-side session backend using the given
|
||||
-- state.
|
||||
-- state. This is a generalized version of 'simpleBackend'.
|
||||
--
|
||||
-- In order to use the Yesod frontend, you 'SessionData' needs to
|
||||
-- implement 'IsSessionMap'.
|
||||
backend
|
||||
:: Storage s
|
||||
=> State s -- ^ @serversession@ state, incl. storage backend.
|
||||
:: (Storage sto, IsSessionMap (SessionData sto))
|
||||
=> State sto -- ^ @serversession@ state, incl. storage backend.
|
||||
-> SessionBackend -- ^ Yesod session backend.
|
||||
backend state =
|
||||
SessionBackend {
|
||||
sbLoadSession = \req -> do
|
||||
backend state = SessionBackend { sbLoadSession = load }
|
||||
where
|
||||
load req = do
|
||||
let rawSessionId = findSessionId cookieNameBS req
|
||||
(sessionMap, saveSessionToken) <- loadSession state rawSessionId
|
||||
(data_, saveSessionToken) <- loadSession state rawSessionId
|
||||
let save =
|
||||
fmap ((:[]) . maybe (deleteCookie state cookieNameBS)
|
||||
(createCookie state cookieNameBS)) .
|
||||
saveSession state saveSessionToken
|
||||
return (sessionMap, save)
|
||||
}
|
||||
where
|
||||
saveSession state saveSessionToken .
|
||||
fromSessionMap
|
||||
return (toSessionMap data_, save)
|
||||
|
||||
cookieNameBS = TE.encodeUtf8 $ getCookieName state
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
-- | Class for session data types meant to be used with the Yesod
|
||||
-- frontend. The only session interface Yesod provides is via
|
||||
-- session variables, so your data type needs to be convertible
|
||||
-- from/to a 'M.Map' of 'Text' to 'ByteString'.
|
||||
class IsSessionMap sess where
|
||||
toSessionMap :: sess -> M.Map Text ByteString
|
||||
fromSessionMap :: M.Map Text ByteString -> sess
|
||||
|
||||
|
||||
instance IsSessionMap SessionMap where
|
||||
toSessionMap = unSessionMap
|
||||
fromSessionMap = SessionMap
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
-- | Create a cookie for the given session.
|
||||
--
|
||||
-- The cookie expiration is set via 'nextExpires'. Note that
|
||||
-- this is just an optimization, as the expiration is checked on
|
||||
-- the server-side as well.
|
||||
createCookie :: State s -> ByteString -> Session -> Header
|
||||
createCookie :: State sto -> ByteString -> Session sess -> Header
|
||||
createCookie state cookieNameBS session =
|
||||
-- Generate a cookie with the final session ID.
|
||||
AddCookie def
|
||||
@ -110,7 +141,7 @@ createCookie state cookieNameBS session =
|
||||
--
|
||||
-- * If the user had a session cookie that was invalidated,
|
||||
-- this will remove the invalid cookie from the client.
|
||||
deleteCookie :: State s -> ByteString -> Header
|
||||
deleteCookie :: State sto -> ByteString -> Header
|
||||
deleteCookie state cookieNameBS =
|
||||
AddCookie def
|
||||
{ C.setCookieName = cookieNameBS
|
||||
|
||||
@ -37,7 +37,9 @@ library
|
||||
OverloadedStrings
|
||||
RecordWildCards
|
||||
ScopedTypeVariables
|
||||
StandaloneDeriving
|
||||
TypeFamilies
|
||||
UndecidableInstances
|
||||
ghc-options: -Wall
|
||||
|
||||
|
||||
@ -53,9 +55,12 @@ test-suite tests
|
||||
, serversession
|
||||
extensions:
|
||||
DeriveDataTypeable
|
||||
FlexibleContexts
|
||||
OverloadedStrings
|
||||
StandaloneDeriving
|
||||
TupleSections
|
||||
TypeFamilies
|
||||
UndecidableInstances
|
||||
main-is: Main.hs
|
||||
ghc-options: -Wall -threaded -with-rtsopts=-N
|
||||
|
||||
|
||||
@ -6,9 +6,11 @@ module Web.ServerSession.Core
|
||||
, Session(..)
|
||||
, Storage(..)
|
||||
, StorageException(..)
|
||||
, IsSessionData(..)
|
||||
, DecomposedSession(..)
|
||||
|
||||
-- * For serversession frontends
|
||||
, SessionMap
|
||||
, SessionMap(..)
|
||||
, State
|
||||
, createState
|
||||
, getCookieName
|
||||
|
||||
@ -1,13 +1,21 @@
|
||||
-- | Internal module exposing the guts of the package. Use at
|
||||
-- your own risk. No API stability guarantees apply.
|
||||
--
|
||||
-- @UndecidableInstances@ is required in order to implement @Eq@,
|
||||
-- @Ord@, @Show@, etc. on data types that have @Decomposed@
|
||||
-- fields, and should be fairly safe.
|
||||
module Web.ServerSession.Core.Internal
|
||||
( SessionId(..)
|
||||
, checkSessionId
|
||||
, generateSessionId
|
||||
|
||||
, SessionMap
|
||||
, AuthId
|
||||
, Session(..)
|
||||
, SessionMap(..)
|
||||
|
||||
, IsSessionData(..)
|
||||
, DecomposedSession(..)
|
||||
|
||||
, Storage(..)
|
||||
, StorageException(..)
|
||||
|
||||
@ -32,10 +40,7 @@ module Web.ServerSession.Core.Internal
|
||||
, saveSession
|
||||
, SaveSessionToken(..)
|
||||
, invalidateIfNeeded
|
||||
, DecomposedSession(..)
|
||||
, decomposeSession
|
||||
, saveSessionOnDb
|
||||
, toSessionMap
|
||||
, forceInvalidateKey
|
||||
, ForceInvalidate(..)
|
||||
) where
|
||||
@ -64,7 +69,8 @@ import qualified Data.Text.Encoding as TE
|
||||
|
||||
|
||||
-- | The ID of a session. Always 18 bytes base64url-encoded as
|
||||
-- 24 characters.
|
||||
-- 24 characters. The @sess@ type variable is a phantom type for
|
||||
-- the session data type this session ID points to.
|
||||
--
|
||||
-- Implementation notes:
|
||||
--
|
||||
@ -72,24 +78,24 @@ import qualified Data.Text.Encoding as TE
|
||||
--
|
||||
-- * Use 'generateSessionId' for securely generating new
|
||||
-- session IDs.
|
||||
newtype SessionId = S { unS :: Text }
|
||||
newtype SessionId sess = S { unS :: Text }
|
||||
deriving (Eq, Ord, Show, Read, Typeable)
|
||||
|
||||
-- | Sanity checks input on 'fromPathPiece' (untrusted input).
|
||||
instance PathPiece SessionId where
|
||||
instance PathPiece (SessionId sess) where
|
||||
toPathPiece = unS
|
||||
fromPathPiece = checkSessionId
|
||||
|
||||
instance A.FromJSON SessionId where
|
||||
instance A.FromJSON (SessionId sess) where
|
||||
parseJSON = fmap S . A.parseJSON
|
||||
|
||||
instance A.ToJSON SessionId where
|
||||
instance A.ToJSON (SessionId sess) where
|
||||
toJSON = A.toJSON . unS
|
||||
|
||||
|
||||
-- | (Internal) Check that the given text is a base64url-encoded
|
||||
-- representation of 18 bytes.
|
||||
checkSessionId :: Text -> Maybe SessionId
|
||||
checkSessionId :: Text -> Maybe (SessionId sess)
|
||||
checkSessionId text = do
|
||||
guard (T.length text == 24)
|
||||
let bs = TE.encodeUtf8 text
|
||||
@ -99,23 +105,13 @@ checkSessionId text = do
|
||||
|
||||
|
||||
-- | Securely generate a new SessionId.
|
||||
generateSessionId :: N.Generator -> IO SessionId
|
||||
generateSessionId :: N.Generator -> IO (SessionId sess)
|
||||
generateSessionId = fmap S . N.nonce128urlT
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
-- | A session map.
|
||||
--
|
||||
-- This is the representation of a session used by the
|
||||
-- @serversession@ family of packages, transferring data between
|
||||
-- this core package and frontend packages. Serversession
|
||||
-- storage backend packages should use 'Session'. End users
|
||||
-- should use their web framework's support for sessions.
|
||||
type SessionMap = M.Map Text ByteString
|
||||
|
||||
|
||||
-- | Value of the 'authKey' session key.
|
||||
type AuthId = ByteString
|
||||
|
||||
@ -124,49 +120,179 @@ type AuthId = ByteString
|
||||
--
|
||||
-- This representation is used by the @serversession@ family of
|
||||
-- packages, transferring data between this core package and
|
||||
-- storage backend packages. Serversession frontend packages
|
||||
-- should use 'SessionMap'. End users should use their web
|
||||
-- framework's support for sessions.
|
||||
data Session =
|
||||
-- storage backend packages. The @sess@ type variable describes
|
||||
-- the session data type.
|
||||
data Session sess =
|
||||
Session
|
||||
{ sessionKey :: SessionId
|
||||
{ sessionKey :: SessionId sess
|
||||
-- ^ Session ID, primary key.
|
||||
, sessionAuthId :: Maybe AuthId
|
||||
-- ^ Value of 'authKey' session key, separate from the rest.
|
||||
, sessionData :: SessionMap
|
||||
, sessionData :: Decomposed sess
|
||||
-- ^ Rest of the session data.
|
||||
, sessionCreatedAt :: UTCTime
|
||||
-- ^ When this session was created.
|
||||
, sessionAccessedAt :: UTCTime
|
||||
-- ^ When this session was last accessed.
|
||||
} deriving (Eq, Ord, Show, Typeable)
|
||||
} deriving (Typeable)
|
||||
|
||||
deriving instance Eq (Decomposed sess) => Eq (Session sess)
|
||||
deriving instance Ord (Decomposed sess) => Ord (Session sess)
|
||||
deriving instance Show (Decomposed sess) => Show (Session sess)
|
||||
|
||||
|
||||
-- | A storage backend for server-side sessions.
|
||||
class MonadIO (TransactionM s) => Storage s where
|
||||
-- | A @newtype@ for a common session map.
|
||||
--
|
||||
-- This is a common representation of a session. Although
|
||||
-- @serversession@ has generalized session data types, you can
|
||||
-- use this one if you don't want to worry about it. We strive
|
||||
-- to support this session data type on all frontends and storage
|
||||
-- backends.
|
||||
newtype SessionMap =
|
||||
SessionMap { unSessionMap :: M.Map Text ByteString }
|
||||
deriving (Eq, Ord, Show, Read, Typeable)
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
-- | Class for data types to be used as session data
|
||||
-- (cf. 'sessionData', 'SessionData').
|
||||
--
|
||||
-- The @Show@ constrain is needed for 'StorageException'.
|
||||
class ( Show (Decomposed sess)
|
||||
, Typeable (Decomposed sess)
|
||||
, Typeable sess
|
||||
) => IsSessionData sess where
|
||||
-- | The type of the session data after being decomposed. This
|
||||
-- may be the same as @sess@.
|
||||
type Decomposed sess :: *
|
||||
|
||||
-- | Empty session data.
|
||||
emptySession :: sess
|
||||
|
||||
-- | Decompose session data into:
|
||||
--
|
||||
-- * The auth ID of the logged in user (cf. 'setAuthKey',
|
||||
-- 'dsAuthId').
|
||||
--
|
||||
-- * If the session is being forced to be invalidated
|
||||
-- (cf. 'forceInvalidateKey', 'ForceInvalidate').
|
||||
--
|
||||
-- * The rest of the session data (cf. 'Decomposed').
|
||||
decomposeSession
|
||||
:: Text -- ^ The auth key (cf. 'setAuthKey').
|
||||
-> sess -- ^ Session data to be decomposed.
|
||||
-> DecomposedSession sess -- ^ Decomposed session data.
|
||||
|
||||
-- | Recompose a decomposed session again into a proper @sess@.
|
||||
recomposeSession
|
||||
:: Text -- ^ The auth key (cf. 'setAuthKey').
|
||||
-> Maybe AuthId -- ^ The @AuthId@, if any.
|
||||
-> Decomposed sess -- ^ Decomposed session data to be recomposed.
|
||||
-> sess -- ^ Recomposed session data.
|
||||
|
||||
-- | Returns @True@ when both session datas are to be
|
||||
-- considered the same.
|
||||
--
|
||||
-- This is used to optimize storage calls
|
||||
-- (cf. 'setTimeoutResolution'). Always returning @False@ will
|
||||
-- disable the optimization but won't have any other adverse
|
||||
-- effects.
|
||||
--
|
||||
-- For data types implementing 'Eq', this is usually a good
|
||||
-- implementation:
|
||||
--
|
||||
-- @
|
||||
-- isSameDecomposed _ = (==)
|
||||
-- @
|
||||
isSameDecomposed :: proxy sess -> Decomposed sess -> Decomposed sess -> Bool
|
||||
|
||||
-- | Returns @True@ if the decomposed session data is to be
|
||||
-- considered @empty@.
|
||||
--
|
||||
-- This is used to avoid storing empty session data if at all
|
||||
-- possible. Always returning @False@ will disable the
|
||||
-- optimization but won't have any other adverse effects.
|
||||
isDecomposedEmpty :: proxy sess -> Decomposed sess -> Bool
|
||||
|
||||
|
||||
-- | A 'SessionMap' decomposes into a 'SessionMap' minus the keys
|
||||
-- that were removed. The auth key is added back when
|
||||
-- recomposing.
|
||||
instance IsSessionData SessionMap where
|
||||
type Decomposed SessionMap = SessionMap
|
||||
|
||||
emptySession = SessionMap M.empty
|
||||
|
||||
isSameDecomposed _ = (==)
|
||||
|
||||
decomposeSession authKey_ (SessionMap sm1) =
|
||||
let (authId, sm2) = M.updateLookupWithKey (\_ _ -> Nothing) authKey_ sm1
|
||||
(force, sm3) = M.updateLookupWithKey (\_ _ -> Nothing) forceInvalidateKey sm2
|
||||
in DecomposedSession
|
||||
{ dsAuthId = authId
|
||||
, dsForceInvalidate = maybe DoNotForceInvalidate (read . B8.unpack) force
|
||||
, dsDecomposed = SessionMap sm3 }
|
||||
|
||||
recomposeSession authKey_ mauthId (SessionMap sm) =
|
||||
SessionMap $ maybe id (M.insert authKey_) mauthId sm
|
||||
|
||||
isDecomposedEmpty _ = M.null . unSessionMap
|
||||
|
||||
|
||||
-- | A session data type @sess@ with its special variables taken apart.
|
||||
data DecomposedSession sess =
|
||||
DecomposedSession
|
||||
{ dsAuthId :: !(Maybe ByteString)
|
||||
, dsForceInvalidate :: !ForceInvalidate
|
||||
, dsDecomposed :: !(Decomposed sess)
|
||||
} deriving (Typeable)
|
||||
|
||||
deriving instance Eq (Decomposed sess) => Eq (DecomposedSession sess)
|
||||
deriving instance Ord (Decomposed sess) => Ord (DecomposedSession sess)
|
||||
deriving instance Show (Decomposed sess) => Show (DecomposedSession sess)
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
-- | A storage backend @sto@ for server-side sessions. The
|
||||
-- @sess@ session data type and\/or its 'Decomposed' version may
|
||||
-- be constrained depending on the storage backend capabilities.
|
||||
class ( Typeable sto
|
||||
, MonadIO (TransactionM sto)
|
||||
, IsSessionData (SessionData sto)
|
||||
) => Storage sto where
|
||||
-- | The session data type used by this storage.
|
||||
type SessionData sto :: *
|
||||
|
||||
-- | Monad where transactions happen for this backend.
|
||||
-- We do not require transactions to be ACID.
|
||||
type TransactionM s :: * -> *
|
||||
type TransactionM sto :: * -> *
|
||||
|
||||
-- | Run a transaction on the IO monad.
|
||||
runTransactionM :: s -> TransactionM s a -> IO a
|
||||
runTransactionM :: sto -> TransactionM sto a -> IO a
|
||||
|
||||
-- | Get the session for the given session ID. Returns
|
||||
-- @Nothing@ if the session is not found.
|
||||
getSession :: s -> SessionId -> TransactionM s (Maybe Session)
|
||||
getSession
|
||||
:: sto
|
||||
-> SessionId (SessionData sto)
|
||||
-> TransactionM sto (Maybe (Session (SessionData sto)))
|
||||
|
||||
-- | Delete the session with given session ID. Does not do
|
||||
-- anything if the session is not found.
|
||||
deleteSession :: s -> SessionId -> TransactionM s ()
|
||||
deleteSession :: sto -> SessionId (SessionData sto) -> TransactionM sto ()
|
||||
|
||||
-- | Delete all sessions of the given auth ID. Does not do
|
||||
-- anything if there are no sessions of the given auth ID.
|
||||
deleteAllSessionsOfAuthId :: s -> AuthId -> TransactionM s ()
|
||||
deleteAllSessionsOfAuthId :: sto -> AuthId -> TransactionM sto ()
|
||||
|
||||
-- | Insert a new session. Throws 'SessionAlreadyExists' if
|
||||
-- there already exists a session with the same session ID (we
|
||||
-- only call this method after generating a fresh session ID).
|
||||
insertSession :: s -> Session -> TransactionM s ()
|
||||
insertSession :: sto -> Session (SessionData sto) -> TransactionM sto ()
|
||||
|
||||
-- | Replace the contents of a session. Throws
|
||||
-- 'SessionDoesNotExist' if there is no session with the given
|
||||
@ -211,29 +337,32 @@ class MonadIO (TransactionM s) => Storage s where
|
||||
-- Most of the time this discussion does not matter.
|
||||
-- Invalidations usually occur at times where only one request
|
||||
-- is flying.
|
||||
replaceSession :: s -> Session -> TransactionM s ()
|
||||
replaceSession :: sto -> Session (SessionData sto) -> TransactionM sto ()
|
||||
|
||||
|
||||
-- | Common exceptions that may be thrown by any storage.
|
||||
data StorageException =
|
||||
data StorageException sto =
|
||||
-- | Exception thrown by 'insertSession' whenever a session
|
||||
-- with same ID already exists.
|
||||
SessionAlreadyExists
|
||||
{ seExistingSession :: Session
|
||||
, seNewSession :: Session }
|
||||
{ seExistingSession :: Session (SessionData sto)
|
||||
, seNewSession :: Session (SessionData sto) }
|
||||
-- | Exception thrown by 'replaceSession' whenever trying to
|
||||
-- replace a session that is not present on the storage.
|
||||
| SessionDoesNotExist
|
||||
{ seNewSession :: Session }
|
||||
deriving (Show, Typeable)
|
||||
{ seNewSession :: Session (SessionData sto) }
|
||||
deriving (Typeable)
|
||||
|
||||
instance E.Exception StorageException where
|
||||
deriving instance Eq (Decomposed (SessionData sto)) => Eq (StorageException sto)
|
||||
deriving instance Ord (Decomposed (SessionData sto)) => Ord (StorageException sto)
|
||||
deriving instance Show (Decomposed (SessionData sto)) => Show (StorageException sto)
|
||||
|
||||
instance Storage sto => E.Exception (StorageException sto) where
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
|
||||
-- TODO: delete expired sessions.
|
||||
|
||||
-- | The server-side session backend needs to maintain some state
|
||||
@ -256,10 +385,10 @@ instance E.Exception StorageException where
|
||||
-- and/or secure ('setSecureCookies').
|
||||
--
|
||||
-- Create a new 'State' using 'createState'.
|
||||
data State s =
|
||||
data State sto =
|
||||
State
|
||||
{ generator :: !N.Generator
|
||||
, storage :: !s
|
||||
, storage :: !sto
|
||||
, cookieName :: !Text
|
||||
, authKey :: !Text
|
||||
, idleTimeout :: !(Maybe NominalDiffTime)
|
||||
@ -273,7 +402,7 @@ data State s =
|
||||
|
||||
-- | Create a new 'State' for the server-side session backend
|
||||
-- using the given storage backend.
|
||||
createState :: MonadIO m => s -> m (State s)
|
||||
createState :: MonadIO m => sto -> m (State sto)
|
||||
createState sto = do
|
||||
gen <- N.new
|
||||
return State
|
||||
@ -294,13 +423,21 @@ createState sto = do
|
||||
-- Defaults to \"JSESSIONID\", which is a generic cookie name
|
||||
-- used by many frameworks thus making it harder to fingerprint
|
||||
-- this implementation.
|
||||
setCookieName :: Text -> State s -> State s
|
||||
setCookieName :: Text -> State sto -> State sto
|
||||
setCookieName val state = state { cookieName = val }
|
||||
|
||||
|
||||
-- | Set the name of the session variable that keeps track of the
|
||||
-- logged user. Defaults to \"_ID\" (used by @yesod-auth@).
|
||||
setAuthKey :: Text -> State s -> State s
|
||||
-- logged user.
|
||||
--
|
||||
-- This setting is used by session data types that are
|
||||
-- @Map@-alike, using a @lookup@ function. However, the
|
||||
-- 'IsSessionData' instance of a session data type may choose not
|
||||
-- to use it. For example, if you implemented a custom data
|
||||
-- type, you could return the @AuthId@ without needing a lookup.
|
||||
--
|
||||
-- Defaults to \"_ID\" (used by @yesod-auth@).
|
||||
setAuthKey :: Text -> State sto -> State sto
|
||||
setAuthKey val state = state { authKey = val }
|
||||
|
||||
|
||||
@ -318,7 +455,7 @@ setAuthKey val state = state { authKey = val }
|
||||
-- (<https://www.owasp.org/index.php/Session_Management_Cheat_Sheet#Idle_Timeout Source>)
|
||||
--
|
||||
-- Defaults to 7 days.
|
||||
setIdleTimeout :: Maybe NominalDiffTime -> State s -> State s
|
||||
setIdleTimeout :: Maybe NominalDiffTime -> State sto -> State sto
|
||||
setIdleTimeout (Just d) _ | d <= 0 = error "serversession/setIdleTimeout: Timeout should be positive."
|
||||
setIdleTimeout val state = state { idleTimeout = val }
|
||||
|
||||
@ -338,7 +475,7 @@ setIdleTimeout val state = state { idleTimeout = val }
|
||||
-- (<https://www.owasp.org/index.php/Session_Management_Cheat_Sheet#Absolute_Timeout Source>)
|
||||
--
|
||||
-- Defaults to 60 days.
|
||||
setAbsoluteTimeout :: Maybe NominalDiffTime -> State s -> State s
|
||||
setAbsoluteTimeout :: Maybe NominalDiffTime -> State sto -> State sto
|
||||
setAbsoluteTimeout (Just d) _ | d <= 0 = error "serversession/setAbsoluteTimeout: Timeout should be positive."
|
||||
setAbsoluteTimeout val state = state { absoluteTimeout = val }
|
||||
|
||||
@ -368,7 +505,7 @@ setAbsoluteTimeout val state = state { absoluteTimeout = val }
|
||||
-- becomes disabled and the session will always be updated.
|
||||
--
|
||||
-- Defaults to 10 minutes.
|
||||
setTimeoutResolution :: Maybe NominalDiffTime -> State s -> State s
|
||||
setTimeoutResolution :: Maybe NominalDiffTime -> State sto -> State sto
|
||||
setTimeoutResolution (Just d) _ | d <= 0 = error "serversession/setTimeoutResolution: Resolution should be positive."
|
||||
setTimeoutResolution val state = state { timeoutResolution = val }
|
||||
|
||||
@ -382,7 +519,7 @@ setTimeoutResolution val state = state { timeoutResolution = val }
|
||||
-- cookie is set to expire in 10 years.
|
||||
--
|
||||
-- Defaults to @True@.
|
||||
setPersistentCookies :: Bool -> State s -> State s
|
||||
setPersistentCookies :: Bool -> State sto -> State sto
|
||||
setPersistentCookies val state = state { persistentCookies = val }
|
||||
|
||||
|
||||
@ -393,7 +530,7 @@ setPersistentCookies val state = state { persistentCookies = val }
|
||||
-- It's highly recommended to set this attribute to @True@.
|
||||
--
|
||||
-- Defaults to @True@.
|
||||
setHttpOnlyCookies :: Bool -> State s -> State s
|
||||
setHttpOnlyCookies :: Bool -> State sto -> State sto
|
||||
setHttpOnlyCookies val state = state { httpOnlyCookies = val }
|
||||
|
||||
|
||||
@ -405,22 +542,22 @@ setHttpOnlyCookies val state = state { httpOnlyCookies = val }
|
||||
-- @False@.
|
||||
--
|
||||
-- Defaults to @False@.
|
||||
setSecureCookies :: Bool -> State s -> State s
|
||||
setSecureCookies :: Bool -> State sto -> State sto
|
||||
setSecureCookies val state = state { secureCookies = val }
|
||||
|
||||
|
||||
-- | Cf. 'setCookieName'.
|
||||
getCookieName :: State s -> Text
|
||||
getCookieName :: State sto -> Text
|
||||
getCookieName = cookieName
|
||||
|
||||
|
||||
-- | Cf. 'setHttpOnlyCookies'.
|
||||
getHttpOnlyCookies :: State s -> Bool
|
||||
getHttpOnlyCookies :: State sto -> Bool
|
||||
getHttpOnlyCookies = httpOnlyCookies
|
||||
|
||||
|
||||
-- | Cf. 'setSecureCookies'.
|
||||
getSecureCookies :: State s -> Bool
|
||||
getSecureCookies :: State sto -> Bool
|
||||
getSecureCookies = secureCookies
|
||||
|
||||
|
||||
@ -432,25 +569,33 @@ getSecureCookies = secureCookies
|
||||
--
|
||||
-- Returns:
|
||||
--
|
||||
-- * The 'SessionMap' to be used by the frontend as the current
|
||||
-- session's value.
|
||||
-- * The session data @sess@ to be used by the frontend as the
|
||||
-- current session's value.
|
||||
--
|
||||
-- * Information to be passed back to 'saveSession' on the end
|
||||
-- of the request in order to save the session.
|
||||
loadSession :: Storage s => State s -> Maybe ByteString -> IO (SessionMap, SaveSessionToken)
|
||||
loadSession
|
||||
:: Storage sto
|
||||
=> State sto
|
||||
-> Maybe ByteString
|
||||
-> IO (SessionData sto, SaveSessionToken sto)
|
||||
loadSession state mcookieVal = do
|
||||
now <- getCurrentTime
|
||||
let maybeInputId = mcookieVal >>= fromPathPiece . TE.decodeUtf8
|
||||
get = runTransactionM (storage state) . getSession (storage state)
|
||||
checkedGet = fmap (>>= checkExpired now state) . get
|
||||
maybeInput <- maybe (return Nothing) checkedGet maybeInputId
|
||||
let inputSessionMap = maybe M.empty (toSessionMap state) maybeInput
|
||||
return (inputSessionMap, SaveSessionToken maybeInput now)
|
||||
let inputData =
|
||||
maybe
|
||||
emptySession
|
||||
(\s -> recomposeSession (authKey state) (sessionAuthId s) (sessionData s))
|
||||
maybeInput
|
||||
return (inputData, SaveSessionToken maybeInput now)
|
||||
|
||||
|
||||
-- | Check if a session @s@ has expired. Returns the @Just s@ if
|
||||
-- not expired, or @Nothing@ if expired.
|
||||
checkExpired :: UTCTime {-^ Now. -} -> State s -> Session -> Maybe Session
|
||||
checkExpired :: UTCTime {-^ Now. -} -> State sto -> Session sess -> Maybe (Session sess)
|
||||
checkExpired now state session =
|
||||
let expired = maybe False (< now) (nextExpires state session)
|
||||
in guard (not expired) >> return session
|
||||
@ -460,7 +605,7 @@ checkExpired now state session =
|
||||
-- will expire assuming that it sees no activity until then.
|
||||
-- Returns @Nothing@ iff the state does not have any expirations
|
||||
-- set to @Just@.
|
||||
nextExpires :: State s -> Session -> Maybe UTCTime
|
||||
nextExpires :: State sto -> Session sess -> Maybe UTCTime
|
||||
nextExpires State {..} Session {..} =
|
||||
let viaIdle = flip addUTCTime sessionAccessedAt <$> idleTimeout
|
||||
viaAbsolute = flip addUTCTime sessionCreatedAt <$> absoluteTimeout
|
||||
@ -471,7 +616,7 @@ nextExpires State {..} Session {..} =
|
||||
|
||||
-- | Calculate the date that should be used for the cookie's
|
||||
-- \"Expires\" field.
|
||||
cookieExpires :: State s -> Session -> Maybe UTCTime
|
||||
cookieExpires :: State sto -> Session sess -> Maybe UTCTime
|
||||
cookieExpires State {..} _ | not persistentCookies = Nothing
|
||||
cookieExpires state session =
|
||||
Just $ fromMaybe tenYearsFromNow $ nextExpires state session
|
||||
@ -481,9 +626,13 @@ cookieExpires state session =
|
||||
|
||||
-- | Opaque token containing the necessary information for
|
||||
-- 'saveSession' to save the session.
|
||||
data SaveSessionToken =
|
||||
SaveSessionToken (Maybe Session) UTCTime
|
||||
deriving (Eq, Show, Typeable)
|
||||
data SaveSessionToken sto =
|
||||
SaveSessionToken (Maybe (Session (SessionData sto))) UTCTime
|
||||
deriving (Typeable)
|
||||
|
||||
deriving instance Eq (Decomposed (SessionData sto)) => Eq (SaveSessionToken sto)
|
||||
deriving instance Ord (Decomposed (SessionData sto)) => Ord (SaveSessionToken sto)
|
||||
deriving instance Show (Decomposed (SessionData sto)) => Show (SaveSessionToken sto)
|
||||
|
||||
|
||||
-- | Save the session on the storage backend. A
|
||||
@ -496,12 +645,17 @@ data SaveSessionToken =
|
||||
-- and clear every other sesssion variable, then 'saveSession'
|
||||
-- will invalidate the older session but will avoid creating a
|
||||
-- new, empty one.
|
||||
saveSession :: Storage s => State s -> SaveSessionToken -> SessionMap -> IO (Maybe Session)
|
||||
saveSession state (SaveSessionToken maybeInput now) wholeOutputSessionMap =
|
||||
saveSession
|
||||
:: Storage sto
|
||||
=> State sto
|
||||
-> SaveSessionToken sto
|
||||
-> SessionData sto
|
||||
-> IO (Maybe (Session (SessionData sto)))
|
||||
saveSession state (SaveSessionToken maybeInput now) outputData =
|
||||
runTransactionM (storage state) $ do
|
||||
let decomposedSessionMap = decomposeSession state wholeOutputSessionMap
|
||||
newMaybeInput <- invalidateIfNeeded state maybeInput decomposedSessionMap
|
||||
saveSessionOnDb state now newMaybeInput decomposedSessionMap
|
||||
let outputDecomp = decomposeSession (authKey state) outputData
|
||||
newMaybeInput <- invalidateIfNeeded state maybeInput outputDecomp
|
||||
saveSessionOnDb state now newMaybeInput outputDecomp
|
||||
|
||||
|
||||
-- | Invalidates an old session ID if needed. Returns the
|
||||
@ -512,11 +666,11 @@ saveSession state (SaveSessionToken maybeInput now) wholeOutputSessionMap =
|
||||
-- fixation attacks. We also invalidate when asked to via
|
||||
-- 'forceInvalidate'.
|
||||
invalidateIfNeeded
|
||||
:: Storage s
|
||||
=> State s
|
||||
-> Maybe Session
|
||||
-> DecomposedSession
|
||||
-> TransactionM s (Maybe Session)
|
||||
:: Storage sto
|
||||
=> State sto
|
||||
-> Maybe (Session (SessionData sto))
|
||||
-> DecomposedSession (SessionData sto)
|
||||
-> TransactionM sto (Maybe (Session (SessionData sto)))
|
||||
invalidateIfNeeded state maybeInput DecomposedSession {..} = do
|
||||
-- Decide which action to take.
|
||||
-- "invalidateOthers implies invalidateCurrent" should be true below.
|
||||
@ -531,26 +685,6 @@ invalidateIfNeeded state maybeInput DecomposedSession {..} = do
|
||||
return $ guard (not invalidateCurrent) >> maybeInput
|
||||
|
||||
|
||||
-- | A 'SessionMap' with its special variables taken apart.
|
||||
data DecomposedSession =
|
||||
DecomposedSession
|
||||
{ dsAuthId :: !(Maybe ByteString)
|
||||
, dsForceInvalidate :: !ForceInvalidate
|
||||
, dsSessionMap :: !SessionMap
|
||||
} deriving (Eq, Show, Typeable)
|
||||
|
||||
|
||||
-- | Decompose a session (see 'DecomposedSession').
|
||||
decomposeSession :: State s -> SessionMap -> DecomposedSession
|
||||
decomposeSession state sm1 =
|
||||
let (authId, sm2) = M.updateLookupWithKey (\_ _ -> Nothing) (authKey state) sm1
|
||||
(force, sm3) = M.updateLookupWithKey (\_ _ -> Nothing) forceInvalidateKey sm2
|
||||
in DecomposedSession
|
||||
{ dsAuthId = authId
|
||||
, dsForceInvalidate = maybe DoNotForceInvalidate (read . B8.unpack) force
|
||||
, dsSessionMap = sm3 }
|
||||
|
||||
|
||||
-- | Save a session on the database. If an old session is
|
||||
-- supplied, it is replaced, otherwise a new session is
|
||||
-- generated. If the session is empty, it is not saved and
|
||||
@ -558,24 +692,30 @@ decomposeSession state sm1 =
|
||||
-- is applied (cf. 'setTimeoutResolution'), the old session is
|
||||
-- returned and no update is made.
|
||||
saveSessionOnDb
|
||||
:: Storage s
|
||||
=> State s
|
||||
-> UTCTime -- ^ Now.
|
||||
-> Maybe Session -- ^ The old session, if any.
|
||||
-> DecomposedSession -- ^ The session data to be saved.
|
||||
-> TransactionM s (Maybe Session) -- ^ Copy of saved session.
|
||||
:: forall sto. Storage sto
|
||||
=> State sto
|
||||
-> UTCTime -- ^ Now.
|
||||
-> Maybe (Session (SessionData sto)) -- ^ The old session, if any.
|
||||
-> DecomposedSession (SessionData sto) -- ^ The session data to be saved.
|
||||
-> TransactionM sto (Maybe (Session (SessionData sto))) -- ^ Copy of saved session.
|
||||
saveSessionOnDb _ _ Nothing (DecomposedSession Nothing _ m)
|
||||
-- Return Nothing without doing anything whenever the session
|
||||
-- is empty (including auth ID) and there was no prior session.
|
||||
| M.null m = return Nothing
|
||||
saveSessionOnDb State { timeoutResolution = Just res } now (Just old) (DecomposedSession authId _ sessionMap)
|
||||
| isDecomposedEmpty proxy m = return Nothing
|
||||
where
|
||||
proxy :: Maybe (SessionData sto)
|
||||
proxy = Nothing
|
||||
saveSessionOnDb State { timeoutResolution = Just res } now (Just old) (DecomposedSession authId _ newSession)
|
||||
-- If the data is the same and the old access time is within
|
||||
-- the timeout resolution, just return the old session without
|
||||
-- doing anything else.
|
||||
| sessionData old == sessionMap &&
|
||||
sessionAuthId old == authId &&
|
||||
| sessionAuthId old == authId &&
|
||||
isSameDecomposed proxy (sessionData old) newSession &&
|
||||
abs (diffUTCTime now (sessionAccessedAt old)) < res =
|
||||
return (Just old)
|
||||
where
|
||||
proxy :: Maybe (SessionData sto)
|
||||
proxy = Nothing
|
||||
saveSessionOnDb state now maybeInput DecomposedSession {..} = do
|
||||
-- Generate properties if needed or take them from previous
|
||||
-- saved session.
|
||||
@ -593,7 +733,7 @@ saveSessionOnDb state now maybeInput DecomposedSession {..} = do
|
||||
let session = Session
|
||||
{ sessionKey = key
|
||||
, sessionAuthId = dsAuthId
|
||||
, sessionData = dsSessionMap
|
||||
, sessionData = dsDecomposed
|
||||
, sessionCreatedAt = createdAt
|
||||
, sessionAccessedAt = now
|
||||
}
|
||||
@ -601,12 +741,6 @@ saveSessionOnDb state now maybeInput DecomposedSession {..} = do
|
||||
return (Just session)
|
||||
|
||||
|
||||
-- | Create a 'SessionMap' from a 'Session'.
|
||||
toSessionMap :: State s -> Session -> SessionMap
|
||||
toSessionMap state Session {..} =
|
||||
maybe id (M.insert $ authKey state) sessionAuthId sessionData
|
||||
|
||||
|
||||
-- | The session key used to signal that the session ID should be
|
||||
-- invalidated.
|
||||
forceInvalidateKey :: Text
|
||||
|
||||
@ -18,7 +18,7 @@ import qualified Data.Text as T
|
||||
import qualified Data.Time as TI
|
||||
|
||||
|
||||
-- | Execute all storage tests.
|
||||
-- | Execute all storage tests using 'SessionMap'.
|
||||
--
|
||||
-- This function is meant to be used with @hspec@. However, we
|
||||
-- don't want to depend on @hspec@, so it takes the relevant
|
||||
@ -42,8 +42,8 @@ import qualified Data.Time as TI
|
||||
-- sequentially in order to reduce the peak memory usage of the
|
||||
-- test suite.
|
||||
allStorageTests
|
||||
:: forall m s. (Monad m, Storage s)
|
||||
=> s -- ^ Storage backend.
|
||||
:: forall m sto. (Monad m, Storage sto, SessionData sto ~ SessionMap)
|
||||
=> sto -- ^ Storage backend.
|
||||
-> (String -> IO () -> m ()) -- ^ @hspec@'s it.
|
||||
-> (forall a. IO a -> m a) -- ^ @hspec@'s runIO.
|
||||
-> (m () -> m ()) -- ^ @hspec@'s parallel
|
||||
@ -52,7 +52,7 @@ allStorageTests
|
||||
-> (forall a e. Exception e => IO a -> (e -> Bool) -> IO ()) -- ^ @hspec@'s shouldThrow.
|
||||
-> m ()
|
||||
allStorageTests storage it runIO parallel _shouldBe shouldReturn shouldThrow = do
|
||||
let run :: forall a. TransactionM s a -> IO a
|
||||
let run :: forall a. TransactionM sto a -> IO a
|
||||
run = runTransactionM storage
|
||||
|
||||
gen <- runIO N.new
|
||||
@ -131,7 +131,8 @@ allStorageTests storage it runIO parallel _shouldBe shouldReturn shouldThrow = d
|
||||
run (insertSession storage s1)
|
||||
run (getSession storage sid) `shouldReturn` Just s1
|
||||
run (insertSession storage s3) `shouldThrow`
|
||||
(\(SessionAlreadyExists s1' s3') -> s1 == s1' && s3 == s3')
|
||||
(\(SessionAlreadyExists s1' s3' :: StorageException sto) ->
|
||||
s1 == s1' && s3 == s3')
|
||||
run (getSession storage sid) `shouldReturn` Just s1
|
||||
|
||||
-- replaceSession
|
||||
@ -153,7 +154,8 @@ allStorageTests storage it runIO parallel _shouldBe shouldReturn shouldThrow = d
|
||||
s <- generateSession gen HasAuthId
|
||||
let sid = sessionKey s
|
||||
run (getSession storage sid) `shouldReturn` Nothing
|
||||
run (replaceSession storage s) `shouldThrow` (\(SessionDoesNotExist s') -> s == s')
|
||||
run (replaceSession storage s) `shouldThrow`
|
||||
(\(SessionDoesNotExist s' :: StorageException sto) -> s == s')
|
||||
run (getSession storage sid) `shouldReturn` Nothing
|
||||
run (insertSession storage s)
|
||||
run (getSession storage sid) `shouldReturn` Just s
|
||||
@ -169,11 +171,11 @@ allStorageTests storage it runIO parallel _shouldBe shouldReturn shouldThrow = d
|
||||
let session = Session
|
||||
{ sessionKey = sid
|
||||
, sessionAuthId = Nothing
|
||||
, sessionData = M.fromList vals
|
||||
, sessionData = SessionMap $ M.fromList vals
|
||||
, sessionCreatedAt = now
|
||||
, sessionAccessedAt = now
|
||||
}
|
||||
ver2 = session { sessionData = M.empty }
|
||||
ver2 = session { sessionData = SessionMap M.empty }
|
||||
run (getSession storage sid) `shouldReturn` Nothing
|
||||
run (insertSession storage session)
|
||||
run (getSession storage sid) `shouldReturn` (Just session)
|
||||
@ -204,7 +206,7 @@ generateAuthId = N.nonce128url
|
||||
|
||||
|
||||
-- | Generate a random session for our tests.
|
||||
generateSession :: N.Generator -> HasAuthId -> IO Session
|
||||
generateSession :: N.Generator -> HasAuthId -> IO (Session SessionMap)
|
||||
generateSession gen hasAuthId = do
|
||||
sid <- generateSessionId gen
|
||||
authId <-
|
||||
@ -219,7 +221,7 @@ generateSession gen hasAuthId = do
|
||||
return Session
|
||||
{ sessionKey = sid
|
||||
, sessionAuthId = authId
|
||||
, sessionData = data_
|
||||
, sessionData = SessionMap data_
|
||||
, sessionCreatedAt = TI.addUTCTime (-1000) now
|
||||
, sessionAccessedAt = now
|
||||
}
|
||||
|
||||
@ -58,7 +58,7 @@ main = hspec $ parallel $ do
|
||||
return $ fromPathPiece (toPathPiece sid) Q.=== Just sid
|
||||
|
||||
it "does not accept as valid some example invalid session IDs" $ do
|
||||
let parse = fromPathPiece :: T.Text -> Maybe SessionId
|
||||
let parse = fromPathPiece :: T.Text -> Maybe (SessionId SessionMap)
|
||||
parse "" `shouldBe` Nothing
|
||||
parse "123456789-123456789-123" `shouldBe` Nothing
|
||||
parse "123456789-123456789-12345" `shouldBe` Nothing
|
||||
@ -95,7 +95,7 @@ main = hspec $ parallel $ do
|
||||
let point1 = 0.1 {- second -} :: Double
|
||||
now <- TI.getCurrentTime
|
||||
abs (realToFrac $ TI.diffUTCTime now time) `shouldSatisfy` (< point1)
|
||||
sessionMap `shouldBe` M.empty
|
||||
sessionMap `shouldBe` TNTSessionData
|
||||
msession `shouldSatisfy` isNothing
|
||||
|
||||
it "returns empty session and token when the session ID cookie is not present" $ do
|
||||
@ -119,7 +119,7 @@ main = hspec $ parallel $ do
|
||||
let session = Session
|
||||
{ sessionKey = S "123456789-123456789-1234"
|
||||
, sessionAuthId = Just authId
|
||||
, sessionData = M.fromList [("a", "b"), ("c", "d")]
|
||||
, sessionData = mkSessionMap [("a", "b"), ("c", "d")]
|
||||
, sessionCreatedAt = TI.addUTCTime (-10) fakenow
|
||||
, sessionAccessedAt = TI.addUTCTime (-5) fakenow
|
||||
}
|
||||
@ -127,7 +127,7 @@ main = hspec $ parallel $ do
|
||||
st <- createState =<< prepareMockStorage [session]
|
||||
(retSessionMap, SaveSessionToken msession _now) <-
|
||||
loadSession st (Just $ B8.pack $ T.unpack $ unS $ sessionKey session)
|
||||
retSessionMap `shouldBe` M.insert (authKey st) authId (sessionData session)
|
||||
retSessionMap `shouldBe` onSM (M.insert (authKey st) authId) (sessionData session)
|
||||
msession `shouldBe` Just session
|
||||
|
||||
describe "checkExpired" $ do
|
||||
@ -214,31 +214,31 @@ main = hspec $ parallel $ do
|
||||
it "works for a complex example" $ do
|
||||
sto <- emptyMockStorage
|
||||
st <- createState sto
|
||||
saveSession st (SaveSessionToken Nothing fakenow) M.empty `shouldReturn` Nothing
|
||||
saveSession st (SaveSessionToken Nothing fakenow) emptySM `shouldReturn` Nothing
|
||||
getMockOperations sto `shouldReturn` []
|
||||
|
||||
let m1 = M.singleton "a" "b"
|
||||
let m1 = mkSessionMap [("a", "b")]
|
||||
Just session1 <- saveSession st (SaveSessionToken Nothing fakenow) m1
|
||||
sessionAuthId session1 `shouldBe` Nothing
|
||||
sessionData session1 `shouldBe` m1
|
||||
getMockOperations sto `shouldReturn` [InsertSession session1]
|
||||
|
||||
let m2 = M.insert (authKey st) "john" m1
|
||||
let m2 = onSM (M.insert (authKey st) "john") m1
|
||||
Just session2 <- saveSession st (SaveSessionToken (Just session1) fakenow) m2
|
||||
sessionAuthId session2 `shouldBe` Just "john"
|
||||
sessionData session2 `shouldBe` m1
|
||||
sessionKey session2 == sessionKey session1 `shouldBe` False
|
||||
getMockOperations sto `shouldReturn` [DeleteSession (sessionKey session1), InsertSession session2]
|
||||
|
||||
let m3 = M.insert forceInvalidateKey (B8.pack $ show AllSessionIdsOfLoggedUser) m2
|
||||
let m3 = onSM (M.insert forceInvalidateKey (B8.pack $ show AllSessionIdsOfLoggedUser)) m2
|
||||
Just session3 <- saveSession st (SaveSessionToken (Just session2) fakenow) m3
|
||||
session3 `shouldBe` session2 { sessionKey = sessionKey session3 }
|
||||
getMockOperations sto `shouldReturn`
|
||||
[DeleteSession (sessionKey session2), DeleteAllSessionsOfAuthId "john", InsertSession session3]
|
||||
|
||||
let m4 = M.insert "x" "y" m2
|
||||
let m4 = onSM (M.insert "x" "y") m2
|
||||
Just session4 <- saveSession st (SaveSessionToken (Just session3) fakenow) m4
|
||||
session4 `shouldBe` session3 { sessionData = M.delete (authKey st) m4 }
|
||||
session4 `shouldBe` session3 { sessionData = onSM (M.delete (authKey st)) m4 }
|
||||
getMockOperations sto `shouldReturn` [ReplaceSession session4]
|
||||
|
||||
Just session5 <- saveSession st (SaveSessionToken (Just session4) (TI.addUTCTime 10 fakenow)) m4
|
||||
@ -250,18 +250,18 @@ main = hspec $ parallel $ do
|
||||
let oldSession = Session
|
||||
{ sessionKey = S "123456789-123456789-1234"
|
||||
, sessionAuthId = authId
|
||||
, sessionData = M.empty
|
||||
, sessionData = emptySM
|
||||
, sessionCreatedAt = TI.addUTCTime (-10) fakenow
|
||||
, sessionAccessedAt = TI.addUTCTime (-5) fakenow }
|
||||
sto <- prepareMockStorage [oldSession]
|
||||
st <- createState sto
|
||||
return (oldSession, sto, st)
|
||||
return (oldSession, sto :: MockStorage SessionMap, st)
|
||||
allEdges = let x = [Nothing, Just "john", Just "jane"] in (,) <$> x <*> x
|
||||
|
||||
it "does not invalidate when not changing auth ID nor explicitly requesting" $ do
|
||||
forM_ [Nothing, Just "john"] $ \authId -> do
|
||||
(session, sto, st) <- prepareInvalidateIfNeeded authId
|
||||
let d = DecomposedSession authId DoNotForceInvalidate M.empty
|
||||
let d = DecomposedSession authId DoNotForceInvalidate emptySM
|
||||
invalidateIfNeeded st (Just session) d `shouldReturn` Just session
|
||||
getMockOperations sto `shouldReturn` []
|
||||
|
||||
@ -270,21 +270,21 @@ main = hspec $ parallel $ do
|
||||
, (Just "admin", Nothing)
|
||||
, (Nothing, Just "joe") ] $ \edgeTransition -> do
|
||||
(session, sto, st) <- prepareInvalidateIfNeeded (fst edgeTransition)
|
||||
let d = DecomposedSession (snd edgeTransition) DoNotForceInvalidate M.empty
|
||||
let d = DecomposedSession (snd edgeTransition) DoNotForceInvalidate emptySM
|
||||
invalidateIfNeeded st (Just session) d `shouldReturn` Nothing
|
||||
getMockOperations sto `shouldReturn` [DeleteSession (sessionKey session)]
|
||||
|
||||
it "invalidates the current session when CurrentSessionId is forced" $ do
|
||||
forM_ allEdges $ \edgeTransition -> do
|
||||
(session, sto, st) <- prepareInvalidateIfNeeded (fst edgeTransition)
|
||||
let d = DecomposedSession (snd edgeTransition) CurrentSessionId M.empty
|
||||
let d = DecomposedSession (snd edgeTransition) CurrentSessionId emptySM
|
||||
invalidateIfNeeded st (Just session) d `shouldReturn` Nothing
|
||||
getMockOperations sto `shouldReturn` [DeleteSession (sessionKey session)]
|
||||
|
||||
it "invalidates all of the user's sessions when AllSessionIdsOfLoggedUser is forced" $ do
|
||||
forM_ allEdges $ \edgeTransition -> do
|
||||
(session, sto, st) <- prepareInvalidateIfNeeded (fst edgeTransition)
|
||||
let d = DecomposedSession (snd edgeTransition) AllSessionIdsOfLoggedUser M.empty
|
||||
let d = DecomposedSession (snd edgeTransition) AllSessionIdsOfLoggedUser emptySM
|
||||
invalidateIfNeeded st (Just session) d `shouldReturn` Nothing
|
||||
let expected = DeleteSession (sessionKey session) :
|
||||
maybe [] ((:[]) . DeleteAllSessionsOfAuthId) (snd edgeTransition)
|
||||
@ -296,19 +296,19 @@ main = hspec $ parallel $ do
|
||||
let oldSession = Session
|
||||
{ sessionKey = S "123456789-123456789-1234"
|
||||
, sessionAuthId = Just "auth"
|
||||
, sessionData = M.fromList [("a", "b"), ("c", "d")]
|
||||
, sessionData = mkSessionMap [("a", "b"), ("c", "d")]
|
||||
, sessionCreatedAt = TI.addUTCTime (-10) fakenow
|
||||
, sessionAccessedAt = TI.addUTCTime (-5) fakenow }
|
||||
sto <- prepareMockStorage [oldSession]
|
||||
st <- createState sto
|
||||
return (oldSession, sto, st)
|
||||
emptyDecomp = DecomposedSession Nothing DoNotForceInvalidate M.empty
|
||||
return (oldSession, sto :: MockStorage SessionMap, st)
|
||||
emptyDecomp = DecomposedSession Nothing DoNotForceInvalidate emptySM
|
||||
|
||||
it "inserts new sessions when there wasn't an old one" $ do
|
||||
sto <- emptyMockStorage
|
||||
st <- createState sto
|
||||
st <- createState (sto :: MockStorage SessionMap)
|
||||
let d = DecomposedSession a DoNotForceInvalidate m
|
||||
m = M.fromList [("a", "b"), ("c", "d")]
|
||||
m = mkSessionMap [("a", "b"), ("c", "d")]
|
||||
a = Just "auth"
|
||||
Just session <- saveSessionOnDb st fakenow Nothing d
|
||||
getMockOperations sto `shouldReturn` [InsertSession session]
|
||||
@ -320,7 +320,7 @@ main = hspec $ parallel $ do
|
||||
it "replaces sesssions when there was an old one" $ do
|
||||
(oldSession, sto, st) <- prepareSaveSessionOnDb
|
||||
let d = DecomposedSession Nothing DoNotForceInvalidate m
|
||||
m = M.fromList [("a", "b"), ("x", "y")]
|
||||
m = mkSessionMap [("a", "b"), ("x", "y")]
|
||||
Just session <- saveSessionOnDb st fakenow (Just oldSession) d
|
||||
getMockOperations sto `shouldReturn` [ReplaceSession session]
|
||||
session `shouldBe` oldSession
|
||||
@ -336,7 +336,7 @@ main = hspec $ parallel $ do
|
||||
|
||||
it "saves session if it's empty but there was an old one" $ do
|
||||
(oldSession, sto, st) <- prepareSaveSessionOnDb
|
||||
let newSession = oldSession { sessionData = M.empty
|
||||
let newSession = oldSession { sessionData = emptySM
|
||||
, sessionAuthId = Nothing
|
||||
, sessionAccessedAt = fakenow }
|
||||
saveSessionOnDb st fakenow (Just oldSession) emptyDecomp `shouldReturn` Just newSession
|
||||
@ -356,46 +356,42 @@ main = hspec $ parallel $ do
|
||||
saveSessionOnDb st (t 1) (Just session1) d `shouldReturn` Just session2
|
||||
getMockOperations sto `shouldReturn` [ReplaceSession session2]
|
||||
|
||||
describe "decomposeSession" $ do
|
||||
describe "decomposeSession/SessionMap" $ do
|
||||
let authKey_ = authKey stnull
|
||||
|
||||
prop "it is sane when not finding auth key or force invalidate key" $
|
||||
\data_ ->
|
||||
let sessionMap = mkSessionMap $ filter (notSpecial . fst) $ data_
|
||||
notSpecial = flip notElem [authKey stnull, forceInvalidateKey] . T.pack
|
||||
in decomposeSession stnull sessionMap `shouldBe`
|
||||
in decomposeSession authKey_ sessionMap `shouldBe`
|
||||
DecomposedSession Nothing DoNotForceInvalidate sessionMap
|
||||
|
||||
prop "parses the force invalidate key" $
|
||||
\data_ ->
|
||||
let sessionMap v = M.insert forceInvalidateKey (B8.pack $ show v) $ mkSessionMap data_
|
||||
let sessionMap v = onSM (M.insert forceInvalidateKey (B8.pack $ show v)) $ mkSessionMap data_
|
||||
allForces = [minBound..maxBound] :: [ForceInvalidate]
|
||||
test v = dsForceInvalidate (decomposeSession stnull $ sessionMap v) Q.=== v
|
||||
test v = dsForceInvalidate (decomposeSession authKey_ $ sessionMap v) Q.=== v
|
||||
in Q.conjoin (test <$> allForces)
|
||||
|
||||
it "removes the auth key" $ do
|
||||
let m = M.singleton "a" "b"; m' = M.insert (authKey stnull) "x" m
|
||||
decomposeSession stnull m' `shouldBe`
|
||||
DecomposedSession (Just "x") DoNotForceInvalidate m
|
||||
decomposeSession authKey_ (SessionMap m') `shouldBe`
|
||||
DecomposedSession (Just "x") DoNotForceInvalidate (SessionMap m)
|
||||
|
||||
describe "toSessionMap" $ do
|
||||
let mkSession authId data_ = Session
|
||||
{ sessionKey = error "irrelevant 1"
|
||||
, sessionAuthId = authId
|
||||
, sessionData = mkSessionMap data_
|
||||
, sessionCreatedAt = error "irrelevant 2"
|
||||
, sessionAccessedAt = error "irrelevant 3"
|
||||
}
|
||||
describe "recomposeSession/SessionMap" $ do
|
||||
let authKey_ = authKey stnull
|
||||
|
||||
prop "does not change session data for sessions without auth ID" $
|
||||
\data_ ->
|
||||
let s = mkSession Nothing data_
|
||||
in toSessionMap stnull s Q.=== sessionData s
|
||||
let s = mkSessionMap data_
|
||||
in recomposeSession authKey_ Nothing s Q.=== s
|
||||
|
||||
prop "adds (overwriting) the auth ID to the session data" $
|
||||
\authId_ data_ ->
|
||||
let s = mkSession (Just authId) ((T.unpack k, "foo") : data_)
|
||||
k = authKey stnull
|
||||
let s = mkSessionMap ((T.unpack authKey_, "foo") : data_)
|
||||
authId = B8.pack authId_
|
||||
in toSessionMap stnull s Q.=== M.adjust (const authId) k (sessionData s)
|
||||
in recomposeSession authKey_ (Just authId) s
|
||||
Q.=== onSM (M.adjust (const authId) authKey_) s
|
||||
|
||||
describe "MockStorage" $ do
|
||||
sto <- runIO emptyMockStorage
|
||||
@ -404,7 +400,19 @@ main = hspec $ parallel $ do
|
||||
|
||||
-- | Used to generate session maps on QuickCheck properties.
|
||||
mkSessionMap :: [(String, String)] -> SessionMap
|
||||
mkSessionMap = M.fromList . map (T.pack *** B8.pack)
|
||||
mkSessionMap = SessionMap . M.fromList . map (T.pack *** B8.pack)
|
||||
|
||||
|
||||
-- | Apply a function to a 'SessionMap'.
|
||||
onSM
|
||||
:: (M.Map T.Text B8.ByteString -> M.Map T.Text B8.ByteString)
|
||||
-> (SessionMap -> SessionMap)
|
||||
onSM f = SessionMap . f . unSessionMap
|
||||
|
||||
|
||||
-- | Empty 'SessionMap'.
|
||||
emptySM :: SessionMap
|
||||
emptySM = emptySession
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
@ -416,6 +424,7 @@ data TNTStorage = TNTStorage deriving (Typeable)
|
||||
|
||||
instance Storage TNTStorage where
|
||||
type TransactionM TNTStorage = IO
|
||||
type SessionData TNTStorage = TNTSessionData
|
||||
runTransactionM _ = id
|
||||
getSession = explode "getSession"
|
||||
deleteSession = explode "deleteSession"
|
||||
@ -436,29 +445,52 @@ data TNTExplosion = TNTExplosion String String deriving (Show, Typeable)
|
||||
instance E.Exception TNTExplosion where
|
||||
|
||||
|
||||
-- | Session data that explodes if it's used. Doesn't explode on
|
||||
-- 'emptySession'.
|
||||
data TNTSessionData = TNTSessionData deriving (Eq, Show, Typeable)
|
||||
|
||||
instance IsSessionData TNTSessionData where
|
||||
type Decomposed TNTSessionData = ()
|
||||
emptySession = TNTSessionData
|
||||
isSameDecomposed _ = curry (explodeD "isSameDecomposed")
|
||||
decomposeSession = curry (explodeD "decomposeSession")
|
||||
recomposeSession = (curry . curry) (explodeD "recomposeSession")
|
||||
isDecomposedEmpty _ = explodeD "isDecomposedEmpty"
|
||||
|
||||
|
||||
-- | Implementation of all 'IsSessionData' methods of
|
||||
-- 'TNTSessionData'.
|
||||
explodeD :: Show a => String -> a -> b
|
||||
explodeD fun = E.throw . TNTExplosion fun . show
|
||||
|
||||
|
||||
----------------------------------------------------------------------
|
||||
|
||||
|
||||
-- | A mock operation that was executed.
|
||||
data MockOperation =
|
||||
GetSession SessionId
|
||||
| DeleteSession SessionId
|
||||
data MockOperation sess =
|
||||
GetSession (SessionId sess)
|
||||
| DeleteSession (SessionId sess)
|
||||
| DeleteAllSessionsOfAuthId AuthId
|
||||
| InsertSession Session
|
||||
| ReplaceSession Session
|
||||
deriving (Eq, Show, Typeable)
|
||||
| InsertSession (Session sess)
|
||||
| ReplaceSession (Session sess)
|
||||
deriving (Typeable)
|
||||
|
||||
deriving instance Eq (Decomposed sess) => Eq (MockOperation sess)
|
||||
deriving instance Show (Decomposed sess) => Show (MockOperation sess)
|
||||
|
||||
|
||||
-- | A mock storage used just for testing.
|
||||
data MockStorage =
|
||||
data MockStorage sess =
|
||||
MockStorage
|
||||
{ mockSessions :: I.IORef (M.Map SessionId Session)
|
||||
, mockOperations :: I.IORef [MockOperation]
|
||||
{ mockSessions :: I.IORef (M.Map (SessionId sess) (Session sess))
|
||||
, mockOperations :: I.IORef [MockOperation sess]
|
||||
}
|
||||
deriving (Typeable)
|
||||
|
||||
instance Storage MockStorage where
|
||||
type TransactionM MockStorage = IO
|
||||
instance IsSessionData sess => Storage (MockStorage sess) where
|
||||
type TransactionM (MockStorage sess) = IO
|
||||
type SessionData (MockStorage sess) = sess
|
||||
runTransactionM _ = id
|
||||
getSession sto sid = do
|
||||
-- We need to use atomicModifyIORef instead of readIORef
|
||||
@ -478,7 +510,7 @@ instance Storage MockStorage where
|
||||
M.insertLookupWithKey (\_ v _ -> v) (sessionKey session) session oldMap
|
||||
in maybe
|
||||
(newMap, return ())
|
||||
(\oldVal -> (oldMap, E.throwIO $ SessionAlreadyExists oldVal session))
|
||||
(\oldVal -> (oldMap, mockThrow $ SessionAlreadyExists oldVal session))
|
||||
moldVal
|
||||
addMockOperation sto (InsertSession session)
|
||||
replaceSession sto session = do
|
||||
@ -486,14 +518,22 @@ instance Storage MockStorage where
|
||||
let (moldVal, newMap) =
|
||||
M.insertLookupWithKey (\_ v _ -> v) (sessionKey session) session oldMap
|
||||
in maybe
|
||||
(oldMap, E.throwIO $ SessionDoesNotExist session)
|
||||
(oldMap, mockThrow $ SessionDoesNotExist session)
|
||||
(const (newMap, return ()))
|
||||
moldVal
|
||||
addMockOperation sto (ReplaceSession session)
|
||||
|
||||
|
||||
-- | Specialization of 'E.throwIO' for 'MockStorage'.
|
||||
mockThrow
|
||||
:: IsSessionData sess
|
||||
=> StorageException (MockStorage sess)
|
||||
-> TransactionM (MockStorage sess) a
|
||||
mockThrow = E.throwIO
|
||||
|
||||
|
||||
-- | Creates empty mock storage.
|
||||
emptyMockStorage :: IO MockStorage
|
||||
emptyMockStorage :: IO (MockStorage sess)
|
||||
emptyMockStorage =
|
||||
MockStorage
|
||||
<$> I.newIORef M.empty
|
||||
@ -501,7 +541,7 @@ emptyMockStorage =
|
||||
|
||||
|
||||
-- | Creates mock storage with the given sessions already existing.
|
||||
prepareMockStorage :: [Session] -> IO MockStorage
|
||||
prepareMockStorage :: [Session sess] -> IO (MockStorage sess)
|
||||
prepareMockStorage sessions = do
|
||||
sto <- emptyMockStorage
|
||||
I.writeIORef (mockSessions sto) (M.fromList [(sessionKey s, s) | s <- sessions])
|
||||
@ -510,10 +550,10 @@ prepareMockStorage sessions = do
|
||||
|
||||
-- | Get the list of mock operations that were made and clear
|
||||
-- them. The operations are listed in chronological order.
|
||||
getMockOperations :: MockStorage -> IO [MockOperation]
|
||||
getMockOperations :: MockStorage sess -> IO [MockOperation sess]
|
||||
getMockOperations = flip I.atomicModifyIORef' ((,) [] . reverse) . mockOperations
|
||||
|
||||
|
||||
-- | Add a mock operations to the log.
|
||||
addMockOperation :: MockStorage -> MockOperation -> IO ()
|
||||
addMockOperation :: MockStorage sess -> MockOperation sess -> IO ()
|
||||
addMockOperation sto op = I.atomicModifyIORef' (mockOperations sto) $ \ops -> (op:ops, ())
|
||||
|
||||
Loading…
Reference in New Issue
Block a user