tests pass locally

This commit is contained in:
parsonsmatt 2021-03-26 16:38:32 -06:00
parent 4a546d2698
commit 2b5da6ab6f
2 changed files with 28 additions and 19 deletions

View File

@ -30,6 +30,7 @@ import Common.Test (RunDbMonad)
share [mkPersist sqlSettings, mkMigrate "migrateJSON"] [persistUpperCase| share [mkPersist sqlSettings, mkMigrate "migrateJSON"] [persistUpperCase|
Json Json
value (JSONB Value) value (JSONB Value)
deriving Show
|] |]
cleanJSON cleanJSON

View File

@ -18,7 +18,7 @@ import Data.Map (Map)
import Data.Time import Data.Time
import Control.Arrow ((&&&)) import Control.Arrow ((&&&))
import Control.Monad (void, when) import Control.Monad (void, when)
import Control.Monad.Catch (MonadCatch, catch) import Control.Monad.Catch
import Control.Monad.IO.Class (MonadIO(liftIO)) import Control.Monad.IO.Class (MonadIO(liftIO))
import Control.Monad.Logger (runStderrLoggingT, runNoLoggingT) import Control.Monad.Logger (runStderrLoggingT, runNoLoggingT)
import Control.Monad.Trans.Reader (ReaderT, ask) import Control.Monad.Trans.Reader (ReaderT, ask)
@ -45,6 +45,7 @@ import Database.PostgreSQL.Simple (SqlError(..), ExecStatus(..))
import System.Environment import System.Environment
import Test.Hspec import Test.Hspec
import Test.Hspec.QuickCheck import Test.Hspec.QuickCheck
import qualified Data.ByteString.Lazy as BSL
import Common.Test import Common.Test
import PostgreSQL.MigrateJSON import PostgreSQL.MigrateJSON
@ -894,9 +895,6 @@ testConcatenationOperator =
liftIO $ length y `shouldBe` 1 liftIO $ length y `shouldBe` 1
liftIO $ length z `shouldBe` 2 liftIO $ length z `shouldBe` 2
liftIO $ length w `shouldBe` 7 liftIO $ length w `shouldBe` 7
sqlFailWith "22023" $ selectJSONwhere $ \v ->
v JSON.||. jsonbVal (toJSON $ String "test")
@>. jsonbVal (String "test")
testMinusOperator :: Spec testMinusOperator :: Spec
testMinusOperator = testMinusOperator =
@ -981,14 +979,14 @@ testHashMinusOperator =
createSaneSQL @JSONValue createSaneSQL @JSONValue
(jsonbVal (object ["a" .= False, "b" .= True]) #-. ["a"]) (jsonbVal (object ["a" .= False, "b" .= True]) #-. ["a"])
"SELECT (? #- ?)\nFROM \"Json\"\n" "SELECT (? #- ?)\nFROM \"Json\"\n"
[ PersistLiteralEscaped (encode []) [ PersistLiteralEscaped (BSL.toStrict $ encode $ object ["a" .= False, "b" .= True])
, persistTextArray ["a"] ] , persistTextArray ["a"] ]
it "creates sane SQL (chained)" $ do it "creates sane SQL (chained)" $ do
let obj = object ["a" .= [object ["b" .= True]]] let obj = object ["a" .= [object ["b" .= True]]]
createSaneSQL @JSONValue createSaneSQL @JSONValue
(jsonbVal obj ->. "a" #-. ["0","b"]) (jsonbVal obj ->. "a" #-. ["0","b"])
"SELECT ((? -> ?) #- ?)\nFROM \"Json\"\n" "SELECT ((? -> ?) #- ?)\nFROM \"Json\"\n"
[ PersistLiteralEscaped (encode obj) [ PersistLiteralEscaped (BSL.toStrict $ encode obj)
, PersistText "a" , PersistText "a"
, persistTextArray ["0","b"] ] , persistTextArray ["0","b"] ]
it "works as expected" $ run $ do it "works as expected" $ run $ do
@ -1309,20 +1307,30 @@ fromValue act = from $ \x -> do
persistTextArray :: [T.Text] -> PersistValue persistTextArray :: [T.Text] -> PersistValue
persistTextArray = PersistArray . fmap PersistText persistTextArray = PersistArray . fmap PersistText
sqlFailWith :: (MonadCatch m, MonadIO m) => ByteString -> SqlPersistT (R.ResourceT m) a -> SqlPersistT (R.ResourceT m) () sqlFailWith :: (HasCallStack, MonadCatch m, MonadIO m, Show a) => ByteString -> SqlPersistT (R.ResourceT m) a -> SqlPersistT (R.ResourceT m) ()
sqlFailWith errState f = do sqlFailWith errState f = do
p <- (f >> return True) `catch` success eres <- try f
when p failed case eres of
where success SqlError{sqlState} Left err ->
| sqlState == errState = return False success err
| otherwise = do Right a ->
liftIO $ expectationFailure $ T.unpack $ T.concat liftIO $ expectationFailure $ mconcat
[ "should fail with: ", errStateT [ "should fail with error code: "
, ", but received: ", TE.decodeUtf8 sqlState , T.unpack errStateT
] , ", but got: "
return False , show a
failed = liftIO $ expectationFailure $ "should fail with: " `mappend` T.unpack errStateT ]
errStateT = TE.decodeUtf8 errState where
success SqlError{sqlState}
| sqlState == errState =
pure ()
| otherwise = do
liftIO $ expectationFailure $ T.unpack $ T.concat
[ "should fail with: ", errStateT
, ", but received: ", TE.decodeUtf8 sqlState
]
errStateT =
TE.decodeUtf8 errState
selectJSONwhere selectJSONwhere
:: MonadIO m :: MonadIO m