From 91fa2581933e5615c7dc03e04efe5f455c5fdcfe Mon Sep 17 00:00:00 2001 From: Matt Parsons Date: Mon, 28 Oct 2019 14:06:01 -0600 Subject: [PATCH] Fix the On Clause Ordering issue (#156) * Add failing test * Refactor newIdentFor to not have an error case * annotation for warning * refactoring * Expression parser * holy shit it works * Add a shitload of tests * cross join * Find a failing case * Account for that one case * works * Composability test * okay now it tests something * Documentation updates * Add since, changelog * fix --- changelog.md | 10 +- esqueleto.cabal | 13 +- src/Database/Esqueleto.hs | 7 +- src/Database/Esqueleto/Internal/ExprParser.hs | 82 +++ src/Database/Esqueleto/Internal/Internal.hs | 274 ++++++--- src/Database/Esqueleto/Internal/Sql.hs | 2 + test/Common/Test.hs | 545 +++++++++++++++++- test/PostgreSQL/Test.hs | 8 +- 8 files changed, 837 insertions(+), 104 deletions(-) create mode 100644 src/Database/Esqueleto/Internal/ExprParser.hs diff --git a/changelog.md b/changelog.md index e43a6f8..95f97bc 100644 --- a/changelog.md +++ b/changelog.md @@ -1,3 +1,11 @@ +3.2.0 (unreleased) +======= + +- @parsonsmatt + - [#156](https://github.com/bitemyapp/esqueleto/pull/156): Remove the + restriction that `on` clauses must appear in reverse order to the joining + tables. + 3.1.3 ======== @@ -11,7 +19,7 @@ - [#149](https://github.com/bitemyapp/esqueleto/pull/157): Added `associateJoin` query helpers. 3.1.1 -======== +======= - @JoseD92 - [#149](https://github.com/bitemyapp/esqueleto/pull/149): Added `upsert` support. diff --git a/esqueleto.cabal b/esqueleto.cabal index 239e9ed..e71220f 100644 --- a/esqueleto.cabal +++ b/esqueleto.cabal @@ -1,7 +1,7 @@ cabal-version: 1.12 name: esqueleto -version: 3.1.3 +version: 3.2.0 synopsis: Type-safe EDSL for SQL queries on persistent backends. description: @esqueleto@ is a bare bones, type-safe EDSL for SQL queries that works with unmodified @persistent@ SQL backends. Its language closely resembles SQL, so you don't have to learn new concepts, just new syntax, and it's fairly easy to predict the generated SQL and optimize it for your backend. Most kinds of errors committed when writing SQL are caught as compile-time errors---although it is possible to write type-checked @esqueleto@ queries that fail at runtime. . @@ -31,12 +31,13 @@ library Database.Esqueleto Database.Esqueleto.Internal.Language Database.Esqueleto.Internal.Sql + Database.Esqueleto.Internal.Internal + Database.Esqueleto.Internal.ExprParser Database.Esqueleto.MySQL Database.Esqueleto.PostgreSQL Database.Esqueleto.PostgreSQL.JSON Database.Esqueleto.SQLite other-modules: - Database.Esqueleto.Internal.Internal Database.Esqueleto.Internal.PersistentImport Database.Esqueleto.PostgreSQL.JSON.Instances Paths_esqueleto @@ -45,10 +46,12 @@ library build-depends: base >=4.8 && <5.0 , aeson >=1.0 + , attoparsec >= 0.13 && < 0.14 , blaze-html , bytestring , containers , conduit >=1.3 + , containers , monad-logger , persistent >=2.10.0 && <2.11 , resourcet >=1.2 @@ -75,6 +78,7 @@ test-suite mysql ghc-options: -Wall build-depends: base >=4.8 && <5.0 + , attoparsec , blaze-html , bytestring , conduit >=1.3 @@ -83,6 +87,7 @@ test-suite mysql , exceptions , hspec , monad-logger + , mtl , mysql , mysql-simple , persistent >=2.8.0 && <2.11 @@ -110,6 +115,7 @@ test-suite postgresql build-depends: base >=4.8 && <5.0 , aeson + , attoparsec , blaze-html , bytestring , conduit >=1.3 @@ -118,6 +124,7 @@ test-suite postgresql , exceptions , hspec , monad-logger + , mtl , persistent >=2.10.0 && <2.11 , persistent-postgresql >= 2.10.0 && <2.11 , persistent-template @@ -144,6 +151,7 @@ test-suite sqlite ghc-options: -Wall build-depends: base >=4.8 && <5.0 + , attoparsec , blaze-html , bytestring , conduit >=1.3 @@ -152,6 +160,7 @@ test-suite sqlite , exceptions , hspec , monad-logger + , mtl , persistent >=2.8.0 && <2.11 , persistent-sqlite , persistent-template diff --git a/src/Database/Esqueleto.hs b/src/Database/Esqueleto.hs index 9dce2ff..3230a7a 100644 --- a/src/Database/Esqueleto.hs +++ b/src/Database/Esqueleto.hs @@ -327,16 +327,11 @@ import qualified Database.Persist -- @ -- 'select' $ -- 'from' $ \\(p1 `'InnerJoin`` f `'InnerJoin`` p2) -> do --- 'on' (p2 '^.' PersonId '==.' f '^.' FollowFollowed) -- 'on' (p1 '^.' PersonId '==.' f '^.' FollowFollower) +-- 'on' (p2 '^.' PersonId '==.' f '^.' FollowFollowed) -- return (p1, f, p2) -- @ -- --- /Note carefully that the order of the ON clauses is/ --- /reversed!/ You're required to write your 'on's in reverse --- order because that helps composability (see the documentation --- of 'on' for more details). --- -- We also currently support @UPDATE@ and @DELETE@ statements. -- For example: -- diff --git a/src/Database/Esqueleto/Internal/ExprParser.hs b/src/Database/Esqueleto/Internal/ExprParser.hs new file mode 100644 index 0000000..8806a6b --- /dev/null +++ b/src/Database/Esqueleto/Internal/ExprParser.hs @@ -0,0 +1,82 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} + +-- | This is an internal module. This module may have breaking changes without +-- a corresponding major version bump. If you use this module, please open an +-- issue with your use-case so we can safely support it. +module Database.Esqueleto.Internal.ExprParser where + +import Prelude hiding (takeWhile) + +import Control.Applicative ((<|>)) +import Control.Monad (void) +import Data.Attoparsec.Text +import Data.Set (Set) +import qualified Data.Set as Set +import Data.Text (Text) +import qualified Data.Text as Text +import Database.Persist.Sql + +-- | A type representing the access of a table value. In Esqueleto, we get +-- a guarantee that the access will look something like: +-- +-- @ +-- escape-char [character] escape-char . escape-char [character] escape-char +-- ^^^^^^^^^^^ ^^^^^^^^^^^ +-- table name column name +-- @ +data TableAccess = TableAccess + { tableAccessTable :: Text + , tableAccessColumn :: Text + } + deriving (Eq, Ord, Show) + +-- | Parse a @SqlExpr (Value Bool)@'s textual representation into a list of +-- 'TableAccess' +parseOnExpr :: SqlBackend -> Text -> Either String (Set TableAccess) +parseOnExpr sqlBackend text = do + c <- mkEscapeChar sqlBackend + parseOnly (onExpr c) text + +-- | This function uses the 'connEscapeName' function in the 'SqlBackend' with an +-- empty identifier to pull out an escape character. This implementation works +-- with postgresql, mysql, and sqlite backends. +mkEscapeChar :: SqlBackend -> Either String Char +mkEscapeChar sqlBackend = + case Text.uncons (connEscapeName sqlBackend (DBName "")) of + Nothing -> + Left "Failed to get an escape character from the SQL backend." + Just (c, _) -> + Right c + +type ExprParser a = Char -> Parser a + +onExpr :: ExprParser (Set TableAccess) +onExpr e = Set.fromList <$> many' tableAccesses + where + tableAccesses = do + skipToEscape e "Skipping to an escape char" + parseTableAccess e "Parsing a table access" + +skipToEscape :: ExprParser () +skipToEscape escapeChar = void (takeWhile (/= escapeChar)) + +parseEscapedIdentifier :: ExprParser [Char] +parseEscapedIdentifier escapeChar = do + char escapeChar + str <- parseEscapedChars escapeChar + char escapeChar + pure str + +parseTableAccess :: ExprParser TableAccess +parseTableAccess ec = do + tableAccessTable <- Text.pack <$> parseEscapedIdentifier ec + _ <- char '.' + tableAccessColumn <- Text.pack <$> parseEscapedIdentifier ec + pure TableAccess {..} + +parseEscapedChars :: ExprParser [Char] +parseEscapedChars escapeChar = go + where + twoEscapes = char escapeChar *> char escapeChar + go = many' (notChar escapeChar <|> twoEscapes) diff --git a/src/Database/Esqueleto/Internal/Internal.hs b/src/Database/Esqueleto/Internal/Internal.hs index 5ba734d..0e14f55 100644 --- a/src/Database/Esqueleto/Internal/Internal.hs +++ b/src/Database/Esqueleto/Internal/Internal.hs @@ -25,11 +25,16 @@ -- | This is an internal module, anything exported by this module -- may change without a major version bump. Please use only -- "Database.Esqueleto" if possible. +-- +-- If you use this module, please report what your use case is on the issue +-- tracker so we can safely support it. module Database.Esqueleto.Internal.Internal where +import Control.Applicative ((<|>)) import Control.Arrow ((***), first) import Control.Exception (Exception, throw, throwIO) -import Control.Monad (ap, MonadPlus(..), void) +import qualified Data.Maybe as Maybe +import Control.Monad (guard, ap, MonadPlus(..), void) import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Resource (MonadResource, release) @@ -43,6 +48,8 @@ import qualified Data.Monoid as Monoid import Data.Proxy (Proxy(..)) import Database.Esqueleto.Internal.PersistentImport import Database.Persist.Sql.Util (entityColumnNames, entityColumnCount, parseEntityValues, isIdField, hasCompositeKey) +import qualified Data.Set as Set +import Data.Set (Set) import qualified Control.Monad.Trans.Reader as R import qualified Control.Monad.Trans.State as S import qualified Control.Monad.Trans.Writer as W @@ -57,6 +64,8 @@ import qualified Data.Text.Lazy.Builder as TLB import Data.Typeable (Typeable) import Text.Blaze.Html (Html) +import Database.Esqueleto.Internal.ExprParser (TableAccess(..), parseOnExpr) + -- | (Internal) Start a 'from' query with an entity. 'from' -- does two kinds of magic using 'fromStart', 'fromJoin' and -- 'fromFinish': @@ -80,8 +89,8 @@ fromStart = x let ed = entityDef (getVal x) ident <- newIdentFor (entityDB ed) let ret = EEntity ident - from_ = FromStart ident ed - return (EPreprocessedFrom ret from_) + f' = FromStart ident ed + return (EPreprocessedFrom ret f') getVal :: SqlQuery (SqlExpr (PreprocessedFrom (SqlExpr (Entity a)))) -> Proxy a getVal = const Proxy @@ -94,7 +103,7 @@ fromStartMaybe = maybelize <$> fromStart where maybelize :: SqlExpr (PreprocessedFrom (SqlExpr (Entity a))) -> SqlExpr (PreprocessedFrom (SqlExpr (Maybe (Entity a)))) - maybelize (EPreprocessedFrom ret from_) = EPreprocessedFrom (EMaybe ret) from_ + maybelize (EPreprocessedFrom ret f') = EPreprocessedFrom (EMaybe ret) f' -- | (Internal) Do a @JOIN@. fromJoin @@ -105,71 +114,75 @@ fromJoin fromJoin (EPreprocessedFrom lhsRet lhsFrom) (EPreprocessedFrom rhsRet rhsFrom) = Q $ do let ret = smartJoin lhsRet rhsRet - from_ = FromJoin lhsFrom -- LHS + from' = FromJoin lhsFrom -- LHS (reifyJoinKind ret) -- JOIN rhsFrom -- RHS Nothing -- ON - return (EPreprocessedFrom ret from_) + return (EPreprocessedFrom ret from') -- | (Internal) Finish a @JOIN@. fromFinish :: SqlExpr (PreprocessedFrom a) -> SqlQuery a -fromFinish (EPreprocessedFrom ret from_) = Q $ do - W.tell mempty { sdFromClause = [from_] } +fromFinish (EPreprocessedFrom ret f') = Q $ do + W.tell mempty { sdFromClause = [f'] } return ret -- | @WHERE@ clause: restrict the query's result. where_ :: SqlExpr (Value Bool) -> SqlQuery () where_ expr = Q $ W.tell mempty { sdWhereClause = Where expr } --- | @ON@ clause: restrict the a @JOIN@'s result. The @ON@ --- clause will be applied to the /last/ @JOIN@ that does not --- have an @ON@ clause yet. If there are no @JOIN@s without --- @ON@ clauses (either because you didn't do any @JOIN@, or --- because all @JOIN@s already have their own @ON@ clauses), a --- runtime exception 'OnClauseWithoutMatchingJoinException' is --- thrown. @ON@ clauses are optional when doing @JOIN@s. +-- | An @ON@ clause, useful to describe how two tables are related. Cross joins +-- and tuple-joins do not need an 'on' clause, but 'InnerJoin' and the various +-- outer joins do. -- --- On the simple case of doing just one @JOIN@, for example +-- If you don't include an 'on' clause (or include too many!) then a runtime +-- exception will be thrown. +-- +-- As an example, consider this simple join: -- -- @ --- select $ +-- 'select' $ -- 'from' $ \\(foo `'InnerJoin`` bar) -> do -- 'on' (foo '^.' FooId '==.' bar '^.' BarFooId) -- ... -- @ -- --- there's no ambiguity and the rules above just mean that --- you're allowed to call 'on' only once (as in SQL). If you --- have many joins, then the 'on's are applied on the /reverse/ --- order that the @JOIN@s appear. For example: +-- We need to specify the clause for joining the two columns together. If we had +-- this: -- -- @ --- select $ +-- 'select' $ +-- 'from' $ \\(foo `'CrossJoin`` bar) -> do +-- ... +-- @ +-- +-- Then we can safely omit the 'on' clause, because the cross join will make +-- pairs of all records possible. +-- +-- You can do multiple 'on' clauses in a query. This query joins three tables, +-- and has two 'on' clauses: +-- +-- @ +-- 'select' $ -- 'from' $ \\(foo `'InnerJoin`` bar `'InnerJoin`` baz) -> do -- 'on' (baz '^.' BazId '==.' bar '^.' BarBazId) -- 'on' (foo '^.' FooId '==.' bar '^.' BarFooId) -- ... -- @ -- --- The order is /reversed/ in order to improve composability. --- For example, consider @query1@ and @query2@ below: +-- Old versions of esqueleto required that you provide the 'on' clauses in +-- reverse order. This restriction has been lifted - you can now provide 'on' +-- clauses in any order, and the SQL should work itself out. The above query is +-- now totally equivalent to this: -- -- @ --- let query1 = --- 'from' $ \\(foo `'InnerJoin`` bar) -> do --- 'on' (foo '^.' FooId '==.' bar '^.' BarFooId) --- query2 = --- 'from' $ \\(mbaz `'LeftOuterJoin`` quux) -> do --- return (mbaz '?.' BazName, quux) --- test1 = (,) \<$\> query1 \<*\> query2 --- test2 = flip (,) \<$\> query2 \<*\> query1 +-- 'select' $ +-- 'from' $ \\(foo `'InnerJoin`` bar `'InnerJoin`` baz) -> do +-- 'on' (foo '^.' FooId '==.' bar '^.' BarFooId) +-- 'on' (baz '^.' BazId '==.' bar '^.' BarBazId) +-- ... -- @ --- --- If the order was /not/ reversed, then @test2@ would be --- broken: @query1@'s 'on' would refer to @query2@'s --- 'LeftOuterJoin'. on :: SqlExpr (Value Bool) -> SqlQuery () on expr = Q $ W.tell mempty { sdFromClause = [OnClause expr] } @@ -1396,32 +1409,108 @@ newtype SetClause = SetClause (SqlExpr (Value ())) -- | Collect 'OnClause's on 'FromJoin's. Returns the first -- unmatched 'OnClause's data on error. Returns a list without -- 'OnClauses' on success. -collectOnClauses :: [FromClause] -> Either (SqlExpr (Value Bool)) [FromClause] -collectOnClauses = go [] +collectOnClauses + :: SqlBackend + -> [FromClause] + -> Either (SqlExpr (Value Bool)) [FromClause] +collectOnClauses sqlBackend = go Set.empty [] where - go [] (f@(FromStart _ _):fs) = fmap (f:) (go [] fs) -- fast path - go acc (OnClause expr :fs) = findMatching acc expr >>= flip go fs - go acc (f:fs) = go (f:acc) fs - go acc [] = return $ reverse acc + go is [] (f@(FromStart i _) : fs) = + fmap (f:) (go (Set.insert i is) [] fs) -- fast path + go idents acc (OnClause expr : fs) = do + (idents', a) <- findMatching idents acc expr + go idents' a fs + go idents acc (f:fs) = + go idents (f:acc) fs + go _ acc [] = + return $ reverse acc - findMatching (f : acc) expr = - case tryMatch expr f of - Just f' -> return (f' : acc) - Nothing -> (f:) <$> findMatching acc expr - findMatching [] expr = Left expr + findMatching + :: Set Ident + -> [FromClause] + -> SqlExpr (Value Bool) + -> Either (SqlExpr (Value Bool)) (Set Ident, [FromClause]) + findMatching idents fromClauses expr = + case fromClauses of + f : acc -> + let + idents' = + idents + <> Set.fromList (Maybe.catMaybes [findLeftmostIdent f, findRightmostIdent f]) + in + case tryMatch idents' expr f of + Just (idents'', f') -> + return (idents'', f' : acc) + Nothing -> + fmap (f:) <$> findMatching idents' acc expr + [] -> + Left expr - tryMatch expr (FromJoin l k r onClause) = - matchR `mplus` matchC `mplus` matchL -- right to left - where - matchR = (\r' -> FromJoin l k r' onClause) <$> tryMatch expr r - matchL = (\l' -> FromJoin l' k r onClause) <$> tryMatch expr l - matchC = case onClause of - Nothing | k /= CrossJoinKind - -> return (FromJoin l k r (Just expr)) - | otherwise -> mzero - Just _ -> mzero - tryMatch _ _ = mzero + findRightmostIdent (FromStart i _) = Just i + findRightmostIdent (FromJoin _ _ r _) = findRightmostIdent r + findRightmostIdent (OnClause {}) = Nothing + findLeftmostIdent (FromStart i _) = Just i + findLeftmostIdent (FromJoin l _ _ _) = findLeftmostIdent l + findLeftmostIdent (OnClause {}) = Nothing + + tryMatch + :: Set Ident + -> SqlExpr (Value Bool) + -> FromClause + -> Maybe (Set Ident, FromClause) + tryMatch idents expr fromClause = + case fromClause of + FromJoin l k r onClause -> + matchTable <|> matchR <|> matchC <|> matchL <|> matchPartial -- right to left + where + matchR = fmap (\r' -> FromJoin l k r' onClause) + <$> tryMatch idents expr r + matchL = fmap (\l' -> FromJoin l' k r onClause) + <$> tryMatch idents expr l + matchPartial = do + i1 <- findLeftmostIdent l + i2 <- findRightmostIdent r + guard $ + Set.isSubsetOf + identsInOnClause + (Set.fromList [i1, i2]) + guard $ k /= CrossJoinKind + guard $ Maybe.isNothing onClause + pure (Set.fromList [] <> idents, FromJoin l k r (Just expr)) + matchC = + case onClause of + Nothing + | "?" `T.isInfixOf` renderedExpr -> + return (idents, FromJoin l k r (Just expr)) + | Set.null identsInOnClause -> + return (idents, FromJoin l k r (Just expr)) + | otherwise -> + Nothing + Just _ -> + Nothing + matchTable = do + i1 <- findLeftmostIdent r + i2 <- findRightmostIdent l + guard $ Set.fromList [i1, i2] `Set.isSubsetOf` identsInOnClause + guard $ k /= CrossJoinKind + guard $ Maybe.isNothing onClause + pure (Set.fromList [i1, i2] <> idents, FromJoin l k r (Just expr)) + + _ -> + Nothing + where + identsInOnClause = + onExprToTableIdentifiers + + renderedExpr = + renderExpr sqlBackend expr + + onExprToTableIdentifiers = + Set.map (I . tableAccessTable) + . either error id + . parseOnExpr sqlBackend + $ renderedExpr -- | A complete @WHERE@ clause. data WhereClause = Where (SqlExpr (Value Bool)) @@ -1476,6 +1565,7 @@ type LockingClause = Monoid.Last LockingKind -- | Identifier used for table names. newtype Ident = I T.Text + deriving (Eq, Ord, Show) -- | List of identifiers already in use and supply of temporary @@ -1489,24 +1579,18 @@ initialIdentState = IdentState mempty -- | Create a fresh 'Ident'. If possible, use the given -- 'DBName'. newIdentFor :: DBName -> SqlQuery Ident -newIdentFor = Q . lift . try . unDBName +newIdentFor (DBName original) = Q $ lift $ findFree Nothing where - try orig = do - s <- S.get - let go (t:ts) | t `HS.member` inUse s = go ts - | otherwise = use t - go [] = throw (UnexpectedCaseErr NewIdentForError) - go (possibilities orig) - - possibilities t = t : map addNum [2..] - where - addNum :: Int -> T.Text - addNum = T.append t . T.pack . show - - use t = do - S.modify (\s -> s { inUse = HS.insert t (inUse s) }) - return (I t) - + findFree msuffix = do + let + withSuffix = + maybe id (\suffix -> (<> T.pack (show suffix))) msuffix original + isInUse <- S.gets (HS.member withSuffix . inUse) + if isInUse + then findFree (succ <$> (msuffix <|> Just (1 :: Int))) + else do + S.modify (\s -> s { inUse = HS.insert withSuffix (inUse s) }) + pure (I withSuffix) -- | Information needed to escape and use identifiers. type IdentInfo = (SqlBackend, IdentState) @@ -1914,7 +1998,7 @@ selectSource query = do -- @Value t@. You may use @Value@ to return projections of an -- @Entity@ (see @('^.')@ and @('?.')@) or to return any other -- value calculated on the query (e.g., 'countRows' or --- 'sub_select'). +-- 'subSelect'). -- -- The @SqlSelect a r@ class has functional dependencies that -- allow type information to flow both from @a@ to @r@ and @@ -2209,11 +2293,15 @@ makeSelect info mode_ distinctClause ret = process mode_ plain v = (v, []) -makeFrom :: IdentInfo -> Mode -> [FromClause] -> (TLB.Builder, [PersistValue]) +makeFrom + :: IdentInfo + -> Mode + -> [FromClause] + -> (TLB.Builder, [PersistValue]) makeFrom _ _ [] = mempty makeFrom info mode fs = ret where - ret = case collectOnClauses fs of + ret = case collectOnClauses (fst info) fs of Left expr -> throw $ mkExc expr Right fs' -> keyword $ uncommas' (map (mk Never) fs') keyword = case mode of @@ -2932,3 +3020,33 @@ insertSelect = void . insertSelectCount insertSelectCount :: (MonadIO m, PersistEntity a) => SqlQuery (SqlExpr (Insertion a)) -> SqlWriteT m Int64 insertSelectCount = rawEsqueleto INSERT_INTO . fmap EInsertFinal + +-- | Renders an expression into 'Text'. Only useful for creating a textual +-- representation of the clauses passed to an "On" clause. +-- +-- @since 3.2.0 +renderExpr :: SqlBackend -> SqlExpr (Value Bool) -> T.Text +renderExpr sqlBackend e = + case e of + ERaw _ mkBuilderValues -> do + let (builder, _) = mkBuilderValues (sqlBackend, initialIdentState) + in (builderToText builder) + ECompositeKey mkInfo -> + throw + . RenderExprUnexpectedECompositeKey + . builderToText + . mconcat + . mkInfo + $ (sqlBackend, initialIdentState) + +-- | An exception thrown by 'RenderExpr' - it's not designed to handle composite +-- keys, and will blow up if you give it one. +-- +-- @since 3.2.0 +data RenderExprException = RenderExprUnexpectedECompositeKey T.Text + deriving Show + +-- | +-- +-- @since 3.2.0 +instance Exception RenderExprException diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index 7818624..f1610c4 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -55,6 +55,7 @@ module Database.Esqueleto.Internal.Sql , Mode(..) , NeedParens(..) , IdentState + , renderExpr , initialIdentState , IdentInfo , SqlSelect(..) @@ -71,6 +72,7 @@ module Database.Esqueleto.Internal.Sql , parens , toArgList , builderToText + , Ident(..) ) where import Database.Esqueleto.Internal.Internal diff --git a/test/Common/Test.hs b/test/Common/Test.hs index 4f55d1b..8bef42d 100644 --- a/test/Common/Test.hs +++ b/test/Common/Test.hs @@ -2,6 +2,7 @@ {-# OPTIONS_GHC -fno-warn-deprecations #-} {-# LANGUAGE ConstraintKinds , CPP + , PartialTypeSignatures , UndecidableInstances , EmptyDataDecls , FlexibleContexts @@ -55,7 +56,9 @@ module Common.Test , Unique(..) ) where +import Data.Either import Control.Monad (forM_, replicateM, replicateM_, void) +import Control.Monad.Reader (ask) import Control.Monad.Catch (MonadCatch) #if __GLASGOW_HASKELL__ >= 806 import Control.Monad.Fail (MonadFail) @@ -69,8 +72,8 @@ import Database.Esqueleto import Database.Persist.TH import Test.Hspec import UnliftIO +import qualified Data.Attoparsec.Text as AP -import Database.Persist (PersistValue(..)) import Data.Conduit (ConduitT, (.|), runConduit) import qualified Data.Conduit.List as CL import qualified Data.List as L @@ -80,16 +83,17 @@ import qualified Data.Text.Lazy.Builder as TLB import qualified Data.Text.Internal.Lazy as TL import qualified Database.Esqueleto.Internal.Sql as EI import qualified UnliftIO.Resource as R - - +import qualified Database.Esqueleto.Internal.ExprParser as P -- Test schema share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistUpperCase| Foo name Int Primary name + deriving Show Eq Ord Bar quux FooId + deriving Show Eq Ord Person name String @@ -101,6 +105,18 @@ share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistUpperCase| title String authorId PersonId deriving Eq Show + Comment + body String + blog BlogPostId + deriving Eq Show + Profile + name String + person PersonId + deriving Eq Show + Reply + guy PersonId + body String + deriving Eq Show Lord county String maxlen=100 @@ -160,6 +176,35 @@ share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistUpperCase| Numbers int Int double Double + + JoinOne + name String + deriving Eq Show + + JoinTwo + joinOne JoinOneId + name String + deriving Eq Show + + JoinThree + joinTwo JoinTwoId + name String + deriving Eq Show + + JoinFour + name String + joinThree JoinThreeId + deriving Eq Show + + JoinOther + name String + deriving Eq Show + + JoinMany + name String + joinOther JoinOtherId + joinOne JoinOneId + deriving Eq Show |] -- Unique Test schema @@ -317,6 +362,7 @@ testSelectFrom run = do , (p2e, p1e) , (p2e, p2e) ] + it "works for a self-join via sub_select" $ run $ do p1k <- insert p1 @@ -447,7 +493,7 @@ testSelectFrom run = do testSelectJoin :: Run -> Spec testSelectJoin run = do - describe "select/JOIN" $ do + describe "select:JOIN" $ do it "works with a LEFT OUTER JOIN" $ run $ do p1e <- insert' p1 @@ -604,11 +650,9 @@ testSelectJoin run = do return p liftIO $ (entityVal <$> ps) `shouldBe` [p1] - - testSelectWhere :: Run -> Spec testSelectWhere run = do - describe "select/where_" $ do + describe "select where_" $ do it "works for a simple example with (==.)" $ run $ do p1e <- insert' p1 @@ -828,6 +872,17 @@ testSelectWhere run = do , (p4e, f42, p2e) , (p2e, f21, p1e) ] + it "works for a many-to-many explicit join and on order doesn't matter" $ do + run $ void $ + selectRethrowingQuery $ + from $ \(person `InnerJoin` blog `InnerJoin` comment) -> do + on $ person ^. PersonId ==. blog ^. BlogPostAuthorId + on $ blog ^. BlogPostId ==. comment ^. CommentBlog + pure (person, comment) + + -- we only care that we don't have a SQL error + True `shouldBe` True + it "works for a many-to-many explicit join with LEFT OUTER JOINs" $ run $ do p1e@(Entity p1k _) <- insert' p1 @@ -1461,7 +1516,7 @@ testCountingRows run = do liftIO $ (n :: Int) `shouldBe` expected testRenderSql :: Run -> Spec -testRenderSql run = +testRenderSql run = do describe "testRenderSql" $ do it "works" $ do (queryText, queryVals) <- run $ renderQuerySelect $ @@ -1481,7 +1536,451 @@ testRenderSql run = `shouldBe` [toPersistValue ("Johhny Depp" :: TL.Text)] + describe "renderExpr" $ do + it "renders a value" $ do + (c, expr) <- run $ do + conn <- ask + let Right c = P.mkEscapeChar conn + pure $ (,) c $ EI.renderExpr conn $ + EI.EEntity (EI.I "user") ^. PersonId + ==. EI.EEntity (EI.I "blog_post") ^. BlogPostAuthorId + expr + `shouldBe` + Text.intercalate (Text.singleton c) ["", "user", ".", "id", ""] + <> + " = " + <> + Text.intercalate (Text.singleton c) ["", "blog_post", ".", "authorId", ""] + it "renders ? for a val" $ do + expr <- run $ ask >>= \c -> pure $ EI.renderExpr c (val (PersonKey 0) ==. val (PersonKey 1)) + expr `shouldBe` "? = ?" + describe "EEntity Ident behavior" $ do + let + render :: SqlExpr (Entity val) -> Text.Text + render (EI.EEntity (EI.I ident)) = ident + it "renders sensibly" $ do + results <- run $ do + _ <- insert $ Foo 2 + _ <- insert $ Foo 3 + _ <- insert $ Person "hello" Nothing Nothing 3 + select $ + from $ \(a `LeftOuterJoin` b) -> do + on $ a ^. FooName ==. b ^. PersonFavNum + pure (val (render a), val (render b)) + head results + `shouldBe` + (Value "Foo", Value "Person") + + describe "ExprParser" $ do + let parse parser = AP.parseOnly (parser '#') + describe "parseEscapedChars" $ do + let subject = parse P.parseEscapedChars + it "parses words" $ do + subject "hello world" + `shouldBe` + Right "hello world" + it "only returns a single escape-char if present" $ do + subject "i_am##identifier##" + `shouldBe` + Right "i_am#identifier#" + describe "parseEscapedIdentifier" $ do + let subject = parse P.parseEscapedIdentifier + it "parses the quotes out" $ do + subject "#it's a me, mario#" + `shouldBe` + Right "it's a me, mario" + it "requires a beginning and end quote" $ do + subject "#alas, i have no end" + `shouldSatisfy` + isLeft + describe "parseTableAccess" $ do + let subject = parse P.parseTableAccess + it "parses a table access" $ do + subject "#foo#.#bar#" + `shouldBe` + Right P.TableAccess + { P.tableAccessTable = "foo" + , P.tableAccessColumn = "bar" + } + describe "onExpr" $ do + let subject = parse P.onExpr + it "works" $ do + subject "#foo#.#bar# = #bar#.#baz#" + `shouldBe` do + Right $ S.fromList + [ P.TableAccess + { P.tableAccessTable = "foo" + , P.tableAccessColumn = "bar" + } + , P.TableAccess + { P.tableAccessTable = "bar" + , P.tableAccessColumn = "baz" + } + ] + it "also works with other nonsense" $ do + subject "#foo#.#bar# = 3" + `shouldBe` do + Right $ S.fromList + [ P.TableAccess + { P.tableAccessTable = "foo" + , P.tableAccessColumn = "bar" + } + ] + it "handles a conjunction" $ do + subject "#foo#.#bar# = #bar#.#baz# AND #bar#.#baz# > 10" + `shouldBe` do + Right $ S.fromList + [ P.TableAccess + { P.tableAccessTable = "foo" + , P.tableAccessColumn = "bar" + } + , P.TableAccess + { P.tableAccessTable = "bar" + , P.tableAccessColumn = "baz" + } + ] + it "handles ? okay" $ do + subject "#foo#.#bar# = ?" + `shouldBe` do + Right $ S.fromList + [ P.TableAccess + { P.tableAccessTable = "foo" + , P.tableAccessColumn = "bar" + } + ] + it "handles degenerate cases" $ do + subject "false" `shouldBe` pure mempty + subject "true" `shouldBe` pure mempty + subject "1 = 1" `shouldBe` pure mempty + it "works even if an identifier isn't first" $ do + subject "true and #foo#.#bar# = 2" + `shouldBe` do + Right $ S.fromList + [ P.TableAccess + { P.tableAccessTable = "foo" + , P.tableAccessColumn = "bar" + } + ] + +testOnClauseOrder :: Run -> Spec +testOnClauseOrder run = describe "On Clause Ordering" $ do + let + setup :: MonadIO m => SqlPersistT m () + setup = do + ja1 <- insert (JoinOne "j1 hello") + ja2 <- insert (JoinOne "j1 world") + jb1 <- insert (JoinTwo ja1 "j2 hello") + jb2 <- insert (JoinTwo ja1 "j2 world") + jb3 <- insert (JoinTwo ja2 "j2 foo") + _ <- insert (JoinTwo ja2 "j2 bar") + jc1 <- insert (JoinThree jb1 "j3 hello") + jc2 <- insert (JoinThree jb1 "j3 world") + _ <- insert (JoinThree jb2 "j3 foo") + _ <- insert (JoinThree jb3 "j3 bar") + _ <- insert (JoinThree jb3 "j3 baz") + _ <- insert (JoinFour "j4 foo" jc1) + _ <- insert (JoinFour "j4 bar" jc2) + jd1 <- insert (JoinOther "foo") + jd2 <- insert (JoinOther "bar") + _ <- insert (JoinMany "jm foo hello" jd1 ja1) + _ <- insert (JoinMany "jm foo world" jd1 ja2) + _ <- insert (JoinMany "jm bar hello" jd2 ja1) + _ <- insert (JoinMany "jm bar world" jd2 ja2) + pure () + describe "identical results for" $ do + it "three tables" $ do + abcs <- run $ do + setup + select $ + from $ \(a `InnerJoin` b `InnerJoin` c) -> do + on (a ^. JoinOneId ==. b ^. JoinTwoJoinOne) + on (b ^. JoinTwoId ==. c ^. JoinThreeJoinTwo) + pure (a, b, c) + acbs <- run $ do + setup + select $ + from $ \(a `InnerJoin` b `InnerJoin` c) -> do + on (b ^. JoinTwoId ==. c ^. JoinThreeJoinTwo) + on (a ^. JoinOneId ==. b ^. JoinTwoJoinOne) + pure (a, b, c) + + listsEqualOn abcs acbs $ \(Entity _ j1, Entity _ j2, Entity _ j3) -> + (joinOneName j1, joinTwoName j2, joinThreeName j3) + + it "four tables" $ do + xs0 <- run $ do + setup + select $ + from $ \(a `InnerJoin` b `InnerJoin` c `InnerJoin` d) -> do + on (a ^. JoinOneId ==. b ^. JoinTwoJoinOne) + on (b ^. JoinTwoId ==. c ^. JoinThreeJoinTwo) + on (c ^. JoinThreeId ==. d ^. JoinFourJoinThree) + pure (a, b, c, d) + xs1 <- run $ do + setup + select $ + from $ \(a `InnerJoin` b `InnerJoin` c `InnerJoin` d) -> do + on (a ^. JoinOneId ==. b ^. JoinTwoJoinOne) + on (c ^. JoinThreeId ==. d ^. JoinFourJoinThree) + on (b ^. JoinTwoId ==. c ^. JoinThreeJoinTwo) + pure (a, b, c, d) + xs2 <- run $ do + setup + select $ + from $ \(a `InnerJoin` b `InnerJoin` c `InnerJoin` d) -> do + on (b ^. JoinTwoId ==. c ^. JoinThreeJoinTwo) + on (c ^. JoinThreeId ==. d ^. JoinFourJoinThree) + on (a ^. JoinOneId ==. b ^. JoinTwoJoinOne) + pure (a, b, c, d) + xs3 <- run $ do + setup + select $ + from $ \(a `InnerJoin` b `InnerJoin` c `InnerJoin` d) -> do + on (c ^. JoinThreeId ==. d ^. JoinFourJoinThree) + on (a ^. JoinOneId ==. b ^. JoinTwoJoinOne) + on (b ^. JoinTwoId ==. c ^. JoinThreeJoinTwo) + pure (a, b, c, d) + xs4 <- run $ do + setup + select $ + from $ \(a `InnerJoin` b `InnerJoin` c `InnerJoin` d) -> do + on (c ^. JoinThreeId ==. d ^. JoinFourJoinThree) + on (b ^. JoinTwoId ==. c ^. JoinThreeJoinTwo) + on (a ^. JoinOneId ==. b ^. JoinTwoJoinOne) + pure (a, b, c, d) + + let getNames (j1, j2, j3, j4) = + ( joinOneName (entityVal j1) + , joinTwoName (entityVal j2) + , joinThreeName (entityVal j3) + , joinFourName (entityVal j4) + ) + listsEqualOn xs0 xs1 getNames + listsEqualOn xs0 xs2 getNames + listsEqualOn xs0 xs3 getNames + listsEqualOn xs0 xs4 getNames + + it "associativity of innerjoin" $ do + xs0 <- run $ do + setup + select $ + from $ \(a `InnerJoin` b `InnerJoin` c `InnerJoin` d) -> do + on (a ^. JoinOneId ==. b ^. JoinTwoJoinOne) + on (b ^. JoinTwoId ==. c ^. JoinThreeJoinTwo) + on (c ^. JoinThreeId ==. d ^. JoinFourJoinThree) + pure (a, b, c, d) + + xs1 <- run $ do + setup + select $ + from $ \(a `InnerJoin` b `InnerJoin` (c `InnerJoin` d)) -> do + on (a ^. JoinOneId ==. b ^. JoinTwoJoinOne) + on (b ^. JoinTwoId ==. c ^. JoinThreeJoinTwo) + on (c ^. JoinThreeId ==. d ^. JoinFourJoinThree) + pure (a, b, c, d) + + xs2 <- run $ do + setup + select $ + from $ \(a `InnerJoin` (b `InnerJoin` c) `InnerJoin` d) -> do + on (a ^. JoinOneId ==. b ^. JoinTwoJoinOne) + on (b ^. JoinTwoId ==. c ^. JoinThreeJoinTwo) + on (c ^. JoinThreeId ==. d ^. JoinFourJoinThree) + pure (a, b, c, d) + + xs3 <- run $ do + setup + select $ + from $ \(a `InnerJoin` (b `InnerJoin` c `InnerJoin` d)) -> do + on (a ^. JoinOneId ==. b ^. JoinTwoJoinOne) + on (b ^. JoinTwoId ==. c ^. JoinThreeJoinTwo) + on (c ^. JoinThreeId ==. d ^. JoinFourJoinThree) + pure (a, b, c, d) + + let getNames (j1, j2, j3, j4) = + ( joinOneName (entityVal j1) + , joinTwoName (entityVal j2) + , joinThreeName (entityVal j3) + , joinFourName (entityVal j4) + ) + listsEqualOn xs0 xs1 getNames + listsEqualOn xs0 xs2 getNames + listsEqualOn xs0 xs3 getNames + + it "inner join on two entities" $ do + (xs0, xs1) <- run $ do + pid <- insert $ Person "hello" Nothing Nothing 3 + _ <- insert $ BlogPost "good poast" pid + _ <- insert $ Profile "cool" pid + xs0 <- selectRethrowingQuery $ + from $ \(p `InnerJoin` b `InnerJoin` pr) -> do + on $ p ^. PersonId ==. b ^. BlogPostAuthorId + on $ p ^. PersonId ==. pr ^. ProfilePerson + pure (p, b, pr) + xs1 <- selectRethrowingQuery $ + from $ \(p `InnerJoin` b `InnerJoin` pr) -> do + on $ p ^. PersonId ==. pr ^. ProfilePerson + on $ p ^. PersonId ==. b ^. BlogPostAuthorId + pure (p, b, pr) + pure (xs0, xs1) + listsEqualOn xs0 xs1 $ \(Entity _ p, Entity _ b, Entity _ pr) -> + (personName p, blogPostTitle b, profileName pr) + it "inner join on three entities" $ do + res <- run $ do + pid <- insert $ Person "hello" Nothing Nothing 3 + _ <- insert $ BlogPost "good poast" pid + _ <- insert $ BlogPost "good poast #2" pid + _ <- insert $ Profile "cool" pid + _ <- insert $ Reply pid "u wot m8" + _ <- insert $ Reply pid "how dare you" + + bprr <- selectRethrowingQuery $ + from $ \(p `InnerJoin` b `InnerJoin` pr `InnerJoin` r) -> do + on $ p ^. PersonId ==. b ^. BlogPostAuthorId + on $ p ^. PersonId ==. pr ^. ProfilePerson + on $ p ^. PersonId ==. r ^. ReplyGuy + pure (p, b, pr, r) + + brpr <- selectRethrowingQuery $ + from $ \(p `InnerJoin` b `InnerJoin` pr `InnerJoin` r) -> do + on $ p ^. PersonId ==. b ^. BlogPostAuthorId + on $ p ^. PersonId ==. r ^. ReplyGuy + on $ p ^. PersonId ==. pr ^. ProfilePerson + pure (p, b, pr, r) + + prbr <- selectRethrowingQuery $ + from $ \(p `InnerJoin` b `InnerJoin` pr `InnerJoin` r) -> do + on $ p ^. PersonId ==. pr ^. ProfilePerson + on $ p ^. PersonId ==. b ^. BlogPostAuthorId + on $ p ^. PersonId ==. r ^. ReplyGuy + pure (p, b, pr, r) + + prrb <- selectRethrowingQuery $ + from $ \(p `InnerJoin` b `InnerJoin` pr `InnerJoin` r) -> do + on $ p ^. PersonId ==. pr ^. ProfilePerson + on $ p ^. PersonId ==. r ^. ReplyGuy + on $ p ^. PersonId ==. b ^. BlogPostAuthorId + pure (p, b, pr, r) + + rprb <- selectRethrowingQuery $ + from $ \(p `InnerJoin` b `InnerJoin` pr `InnerJoin` r) -> do + on $ p ^. PersonId ==. r ^. ReplyGuy + on $ p ^. PersonId ==. pr ^. ProfilePerson + on $ p ^. PersonId ==. b ^. BlogPostAuthorId + pure (p, b, pr, r) + + rbpr <- selectRethrowingQuery $ + from $ \(p `InnerJoin` b `InnerJoin` pr `InnerJoin` r) -> do + on $ p ^. PersonId ==. r ^. ReplyGuy + on $ p ^. PersonId ==. b ^. BlogPostAuthorId + on $ p ^. PersonId ==. pr ^. ProfilePerson + pure (p, b, pr, r) + + pure [bprr, brpr, prbr, prrb, rprb, rbpr] + forM_ (zip res (drop 1 (cycle res))) $ \(a, b) -> a `shouldBe` b + + it "many-to-many" $ do + ac <- run $ do + setup + select $ + from $ \(a `InnerJoin` b `InnerJoin` c) -> do + on (a ^. JoinOneId ==. b ^. JoinManyJoinOne) + on (c ^. JoinOtherId ==. b ^. JoinManyJoinOther) + pure (a, c) + + ca <- run $ do + setup + select $ + from $ \(a `InnerJoin` b `InnerJoin` c) -> do + on (c ^. JoinOtherId ==. b ^. JoinManyJoinOther) + on (a ^. JoinOneId ==. b ^. JoinManyJoinOne) + pure (a, c) + + listsEqualOn ac ca $ \(Entity _ a, Entity _ b) -> + (joinOneName a, joinOtherName b) + + it "left joins on order" $ do + ca <- run $ do + setup + select $ + from $ \(a `LeftOuterJoin` b `InnerJoin` c) -> do + on (c ?. JoinOtherId ==. b ?. JoinManyJoinOther) + on (just (a ^. JoinOneId) ==. b ?. JoinManyJoinOne) + orderBy [asc $ a ^. JoinOneId, asc $ c ?. JoinOtherId] + pure (a, c) + ac <- run $ do + setup + select $ + from $ \(a `LeftOuterJoin` b `InnerJoin` c) -> do + on (just (a ^. JoinOneId) ==. b ?. JoinManyJoinOne) + on (c ?. JoinOtherId ==. b ?. JoinManyJoinOther) + orderBy [asc $ a ^. JoinOneId, asc $ c ?. JoinOtherId] + pure (a, c) + + listsEqualOn ac ca $ \(Entity _ a, b) -> + (joinOneName a, maybe "NULL" (joinOtherName . entityVal) b) + + it "doesn't require an on for a crossjoin" $ do + void $ run $ + select $ + from $ \(a `CrossJoin` b) -> do + pure (a :: SqlExpr (Entity JoinOne), b :: SqlExpr (Entity JoinTwo)) + + it "errors with an on for a crossjoin" $ do + (void $ run $ + select $ + from $ \(a `CrossJoin` b) -> do + on $ a ^. JoinOneId ==. b ^. JoinTwoJoinOne + pure (a, b)) + `shouldThrow` \(OnClauseWithoutMatchingJoinException _) -> + True + + it "left joins associativity" $ do + ca <- run $ do + setup + select $ + from $ \(a `LeftOuterJoin` (b `InnerJoin` c)) -> do + on (c ?. JoinOtherId ==. b ?. JoinManyJoinOther) + on (just (a ^. JoinOneId) ==. b ?. JoinManyJoinOne) + orderBy [asc $ a ^. JoinOneId, asc $ c ?. JoinOtherId] + pure (a, c) + ca' <- run $ do + setup + select $ + from $ \(a `LeftOuterJoin` b `InnerJoin` c) -> do + on (c ?. JoinOtherId ==. b ?. JoinManyJoinOther) + on (just (a ^. JoinOneId) ==. b ?. JoinManyJoinOne) + orderBy [asc $ a ^. JoinOneId, asc $ c ?. JoinOtherId] + pure (a, c) + + listsEqualOn ca ca' $ \(Entity _ a, b) -> + (joinOneName a, maybe "NULL" (joinOtherName . entityVal) b) + + it "composes queries still" $ do + let + query1 = + from $ \(foo `InnerJoin` bar) -> do + on (foo ^. FooId ==. bar ^. BarQuux) + pure (foo, bar) + query2 = + from $ \(p `LeftOuterJoin` bp) -> do + on (p ^. PersonId ==. bp ^. BlogPostAuthorId) + pure (p, bp) + (a, b) <- run $ do + fid <- insert $ Foo 5 + _ <- insert $ Bar fid + pid <- insert $ Person "hey" Nothing Nothing 30 + _ <- insert $ BlogPost "WHY" pid + a <- select ((,) <$> query1 <*> query2) + b <- select (flip (,) <$> query1 <*> query2) + pure (a, b) + listsEqualOn a (map (\(x, y) -> (y, x)) b) id + + + +listsEqualOn :: (Show a1, Eq a1) => [a2] -> [a2] -> (a2 -> a1) -> Expectation +listsEqualOn a b f = map f a `shouldBe` map f b tests :: Run -> Spec tests run = do @@ -1503,6 +2002,7 @@ tests run = do testCase run testCountingRows run testRenderSql run + testOnClauseOrder run insert' :: ( Functor m @@ -1535,12 +2035,15 @@ cleanDB :: (forall m. RunDbMonad m => SqlPersistT (R.ResourceT m) ()) cleanDB = do - delete $ from $ \(_ :: SqlExpr (Entity Foo)) -> return () delete $ from $ \(_ :: SqlExpr (Entity Bar)) -> return () + delete $ from $ \(_ :: SqlExpr (Entity Foo)) -> return () - delete $ from $ \(_ :: SqlExpr (Entity BlogPost)) -> return () - delete $ from $ \(_ :: SqlExpr (Entity Follow)) -> return () - delete $ from $ \(_ :: SqlExpr (Entity Person)) -> return () + delete $ from $ \(_ :: SqlExpr (Entity Reply)) -> return () + delete $ from $ \(_ :: SqlExpr (Entity Comment)) -> return () + delete $ from $ \(_ :: SqlExpr (Entity Profile)) -> return () + delete $ from $ \(_ :: SqlExpr (Entity BlogPost)) -> return () + delete $ from $ \(_ :: SqlExpr (Entity Follow)) -> return () + delete $ from $ \(_ :: SqlExpr (Entity Person)) -> return () delete $ from $ \(_ :: SqlExpr (Entity Deed)) -> return () delete $ from $ \(_ :: SqlExpr (Entity Lord)) -> return () @@ -1557,10 +2060,26 @@ cleanDB = do delete $ from $ \(_ :: SqlExpr (Entity Point)) -> return () delete $ from $ \(_ :: SqlExpr (Entity Numbers)) -> return () + delete $ from $ \(_ :: SqlExpr (Entity JoinMany)) -> return () + delete $ from $ \(_ :: SqlExpr (Entity JoinFour)) -> return () + delete $ from $ \(_ :: SqlExpr (Entity JoinThree)) -> return () + delete $ from $ \(_ :: SqlExpr (Entity JoinTwo)) -> return () + delete $ from $ \(_ :: SqlExpr (Entity JoinOne)) -> return () + delete $ from $ \(_ :: SqlExpr (Entity JoinOther)) -> return () cleanUniques :: (forall m. RunDbMonad m => SqlPersistT (R.ResourceT m) ()) cleanUniques = - delete $ from $ \(_ :: SqlExpr (Entity OneUnique)) -> return () \ No newline at end of file + delete $ from $ \(_ :: SqlExpr (Entity OneUnique)) -> return () + +selectRethrowingQuery + :: (MonadIO m, EI.SqlSelect a r, MonadUnliftIO m) + => SqlQuery a + -> SqlPersistT m [r] +selectRethrowingQuery query = + select query + `catch` \(SomeException e) -> do + (text, _) <- renderQuerySelect query + liftIO . throwIO . userError $ Text.unpack text <> "\n\n" <> show e diff --git a/test/PostgreSQL/Test.hs b/test/PostgreSQL/Test.hs index 4ef69c4..b008f13 100644 --- a/test/PostgreSQL/Test.hs +++ b/test/PostgreSQL/Test.hs @@ -958,10 +958,10 @@ testInsertUniqueViolation = insert u3) `shouldThrow` (==) exception where exception = SqlError { - sqlState = "23505", - sqlExecStatus = FatalError, - sqlErrorMsg = "duplicate key value violates unique constraint \"UniqueValue\"", - sqlErrorDetail = "Key (value)=(0) already exists.", + sqlState = "23505", + sqlExecStatus = FatalError, + sqlErrorMsg = "duplicate key value violates unique constraint \"UniqueValue\"", + sqlErrorDetail = "Key (value)=(0) already exists.", sqlErrorHint = ""} testUpsert :: Spec