-- 303 lines, 10 KiB, Haskell
{-# LANGUAGE DataKinds #-}
|
|
{-# LANGUAGE FlexibleContexts #-}
|
|
{-# LANGUAGE FlexibleInstances #-}
|
|
{-# LANGUAGE LambdaCase #-}
|
|
{-# LANGUAGE MultiParamTypeClasses #-}
|
|
{-# LANGUAGE OverloadedStrings #-}
|
|
{-# LANGUAGE TypeFamilies #-}
|
|
{-# LANGUAGE TypeOperators #-}
|
|
{-# LANGUAGE UndecidableInstances #-}
|
|
|
|
module Servant.Server.Utils.CustomCombinatorsSpec where
|
|
|
|
import Control.Concurrent
|
|
import Data.ByteString.Builder
|
|
import Control.DeepSeq
|
|
import Control.Monad.IO.Class
|
|
import Data.ByteString as SBS hiding (map)
|
|
import Data.ByteString.Lazy as LBS hiding (map)
|
|
import Data.Monoid
|
|
import Data.Proxy
|
|
import Data.String.Conversions
|
|
import Data.Text hiding (map)
|
|
import Network.HTTP.Types
|
|
import Network.Wai
|
|
import Network.Wai.Internal
|
|
import Test.Hspec hiding (context)
|
|
|
|
import Servant.API
|
|
import Servant.Server
|
|
import Servant.Server.Utils.CustomCombinators
|
|
|
|
-- | Runs a wai 'Application' against a single 'Request' and returns the
-- 'Response' that the application handed to its respond callback.
--
-- Calls 'error' if the application invokes the respond callback more than
-- once, or if it finishes without ever having stored a response.
runApp :: Application -> Request -> IO Response
runApp app req = do
  mvar <- newMVar Nothing
  ResponseReceived <- app req $ \ response ->
    modifyMVar mvar $ \case
      Nothing -> return (Just response, ResponseReceived)
      -- Previously only 'Nothing' was matched here, so a second response
      -- crashed with an anonymous pattern match failure.
      Just _ -> error "runApp: application responded more than once"
  -- The stored value is only inspected, never changed, so a plain read
  -- is sufficient.
  readMVar mvar >>= \case
    Nothing -> error "shouldn't happen"
    Just response -> return response
|
|
|
|
-- | Runs the streaming body of a 'Response' and returns everything it wrote,
-- concatenated in order, as one lazy 'LBS.ByteString'.
responseBodyLbs :: Response -> IO LBS.ByteString
responseBodyLbs response =
  withBody $ \ body -> do
    collected <- newMVar ""
    let -- append every chunk the body writes to what was collected so far
        write builder =
          modifyMVar_ collected (return . (<> toLazyByteString builder))
        -- flushing needs no action, everything is kept in memory
        flush = return ()
    body write flush
    readMVar collected
  where
    (_, _, withBody) = responseToStream response
|
|
|
|
-- | Exercises the combinators defined further down in this module: each
-- example serves a one-endpoint API, runs a single request through it with
-- 'runApp' and checks the status and/or the body of the response.
spec :: Spec
spec = do
  describe "makeCaptureCombinator" $ do
    it "allows to write capture combinators" $ do
      let api = Proxy :: Proxy (StringCapture :> Get' String)
          fooRequest =
            defaultRequest
              { rawPathInfo = "/foo"
              , pathInfo = ["foo"]
              }
      response <- runApp (serve api return) fooRequest
      responseBodyLbs response `shouldReturn` "\"foo\""

    it "allows to write a combinator that errors out" $ do
      let api = Proxy :: Proxy (StringCapture :> Get' String)
          errorRequest =
            defaultRequest
              { rawPathInfo = "/error"
              , pathInfo = ["error"]
              }
      response <- runApp (serve api return) errorRequest
      responseStatus response `shouldBe` status418

  describe "makeRequestCheckCombinator" $ do
    it "allows to write request check combinators" $ do
      let api = Proxy :: Proxy (CheckFooHeader :> Get' ())
      response <-
        runApp (serve api (return ())) (requestWithHeader ("Foo", "foo"))
      responseBodyLbs response `shouldReturn` "[]"

    it "disallows to access the request body" $ do
      let api = Proxy :: Proxy (InvalidRequestCheckCombinator :> Get' ())
      runApp (serve api (return ())) defaultRequest `shouldThrow`
        errorCall "ERROR: makeRequestCheckCombinator: combinator must not access the request body"

  describe "makeAuthCombinator" $ do
    it "allows to write an auth combinator" $ do
      let api = Proxy :: Proxy (AuthCombinator :> Get' String)
          handler (User name) = return name
      response <-
        runApp (serve api handler) (requestWithHeader ("Auth", "secret"))
      responseStatus response `shouldBe` ok200
      responseBodyLbs response `shouldReturn` "\"Alice\""

    it "disallows to access the request body" $ do
      let api = Proxy :: Proxy (InvalidAuthCombinator :> Get' String)
          handler _user = return "foo"
      runApp (serve api handler) defaultRequest `shouldThrow`
        errorCall "ERROR: makeAuthCombinator: combinator must not access the request body"

    it "allows to access the context" $ do
      let api = Proxy :: Proxy (AuthWithContext :> Get' String)
          handler (User name) = return name
          users :: Context '[ [(SBS.ByteString, User)] ]
          users = [("secret", User "Bob")] :. EmptyContext
      response <-
        runApp
          (serveWithContext api users handler)
          (requestWithHeader ("Auth", "secret"))
      responseStatus response `shouldBe` ok200
      responseBodyLbs response `shouldReturn` "\"Bob\""

  describe "makeCombinator" $ do
    it "allows to write a combinator by providing a function (Request -> a)" $ do
      let api = Proxy :: Proxy (FooHeader :> Get' String)
      response <- runApp (serve api return) (requestWithHeader ("Foo", "foo"))
      responseBodyLbs response `shouldReturn` "\"foo\""

  describe "makeReqBodyCombinator" $ do
    let -- Turns a list of chunk-producing actions into a request body
        -- action: every call runs the next action from the list, and once
        -- the list is used up every call returns the empty chunk.
        toBody :: [IO SBS.ByteString] -> IO (IO SBS.ByteString)
        toBody actions = do
          remaining <- newMVar actions
          return $ modifyMVar remaining $ \case
            [] -> return ([], "")
            (next : rest) -> do
              chunk <- next
              return (rest, chunk)

    it "allows to write combinators" $ do
      body <- toBody [return "foo", return "bar"]
      let api = Proxy :: Proxy (StreamRequest :> Get' String)
          handler (Source getChunk) = liftIO (cs <$> fromBody getChunk)
      response <-
        runApp (serve api handler) (defaultRequest { requestBody = body })
      responseBodyLbs response `shouldReturn` "\"foobar\""

    it "allows to stream lazily" $ do
      gate <- newEmptyMVar
      -- the second chunk only becomes available after the handler has
      -- filled 'gate', i.e. after it has consumed the first chunk
      body <- toBody [return "foo", takeMVar gate >> return "bar"]
      let api = Proxy :: Proxy (StreamRequest :> Get' String)
          handler (Source getChunk) = liftIO $ do
            first <- getChunk
            deepseq first (return ())
            putMVar gate ()
            rest <- fromBody getChunk
            return $ cs (first <> rest)
      response <-
        runApp (serve api handler) (defaultRequest { requestBody = body })
      responseBodyLbs response `shouldReturn` "\"foobar\""

    it "allows to implement combinators in terms of existing combinators" pending
  where
    -- 'defaultRequest' with one additional header in front of its headers
    requestWithHeader header =
      defaultRequest
        { requestHeaders = header : requestHeaders defaultRequest
        }
|
|
|
|
-- | Shorthand for a 'Get' endpoint whose content type list is just 'JSON'.
type Get' = Get '[JSON]
|
|
|
|
-- * capture combinators
|
|
|
|
-- | Combinator built with 'makeCaptureCombinator'. Handlers below it take an
-- additional 'String' argument, which is produced by 'getCapture'.
data StringCapture

instance HasServer api context => HasServer (StringCapture :> api) context where
  type ServerT (StringCapture :> api) m = String -> ServerT api m

  -- the context is not needed for this combinator
  route = runCI (makeCaptureCombinator (\ _context -> getCapture))
|
|
|
|
-- | Converts the given snippet to a 'String'. The special snippet @"error"@
-- is rejected with a fatal 418 error instead.
getCapture :: Text -> IO (RouteResult String)
getCapture snippet =
  return $
    if snippet == "error"
      then FailFatal (ServantErr 418 "I'm a teapot" "" [])
      else Route (cs snippet)
|
|
|
|
-- * request check combinators
|
|
|
|
-- | Combinator built with 'makeRequestCheckCombinator' from
-- 'checkFooHeader'. It does not change the type of the handlers below it.
data CheckFooHeader

instance HasServer api context => HasServer (CheckFooHeader :> api) context where
  type ServerT (CheckFooHeader :> api) m = ServerT api m

  -- the context is not needed for this combinator
  route = runCI (makeRequestCheckCombinator (\ _context -> checkFooHeader))
|
|
|
|
-- | Succeeds if the request has a @Foo@ header (with any value) and fails
-- fatally with 'err400' otherwise.
checkFooHeader :: Request -> IO (RouteResult ())
checkFooHeader request =
  return (maybe (FailFatal err400) (const (Route ())) fooHeader)
  where
    fooHeader = lookup "Foo" (requestHeaders request)
|
|
|
|
-- | A request check combinator that is deliberately invalid: its check,
-- 'accessReqBody', reads the request body.
data InvalidRequestCheckCombinator

instance HasServer api context => HasServer (InvalidRequestCheckCombinator :> api) context where
  type ServerT (InvalidRequestCheckCombinator :> api) m = ServerT api m

  -- the context is not needed for this combinator
  route = runCI (makeRequestCheckCombinator (\ _context -> accessReqBody))
|
|
|
|
-- | Reads the complete request body, forces it, and then succeeds.
accessReqBody :: Request -> IO (RouteResult ())
accessReqBody request =
  fromBody (requestBody request) >>= \ wholeBody ->
    wholeBody `deepseq` return (Route ())
|
|
|
|
-- * auth combinators
|
|
|
|
-- | Combinator built with 'makeAuthCombinator' from 'checkAuth'. Handlers
-- below it take the resulting 'User' as an additional argument.
data AuthCombinator

-- | A user, wrapping a single 'String' (used as the name in the specs).
data User
  = User String
  deriving (Eq, Show)

instance HasServer api context => HasServer (AuthCombinator :> api) context where
  type ServerT (AuthCombinator :> api) m = User -> ServerT api m

  -- the context is not needed for this combinator
  route = runCI (makeAuthCombinator (\ _context -> checkAuth))
|
|
|
|
-- | Looks at the @Auth@ header of the request: the value @"secret"@ yields
-- the user @Alice@, any other value fails fatally with 'err401', and a
-- missing header fails fatally with 'err400'.
checkAuth :: Request -> IO (RouteResult User)
checkAuth request =
  return (authenticate (lookup "Auth" (requestHeaders request)))
  where
    authenticate Nothing = FailFatal err400
    authenticate (Just token)
      | token == "secret" = Route (User "Alice")
      | otherwise = FailFatal err401
|
|
|
|
-- | An auth combinator that is deliberately invalid: its check,
-- 'authWithReqBody', reads the request body.
data InvalidAuthCombinator

instance HasServer api context => HasServer (InvalidAuthCombinator :> api) context where
  type ServerT (InvalidAuthCombinator :> api) m = User -> ServerT api m

  -- the context is not needed for this combinator
  route = runCI (makeAuthCombinator (\ _context -> authWithReqBody))
|
|
|
|
-- | Reads the complete request body, forces it, and returns a 'User' whose
-- name is that body.
authWithReqBody :: Request -> IO (RouteResult User)
authWithReqBody request = do
  wholeBody <- fromBody (requestBody request)
  let user = User (cs wholeBody)
  wholeBody `deepseq` return (Route user)
|
|
|
|
-- | Combinator built with 'makeAuthCombinator' from 'authWithContext'. It
-- requires a context entry of type @[(SBS.ByteString, User)]@ and passes a
-- 'User' to the handlers below it.
data AuthWithContext

instance
  ( HasServer api context
  , HasContextEntry context [(SBS.ByteString, User)]
  ) =>
  HasServer (AuthWithContext :> api) context
  where
  type ServerT (AuthWithContext :> api) m = User -> ServerT api m

  route = runCI (makeAuthCombinator authWithContext)
|
|
|
|
-- fixme: remove foralls from haddock
|
|
|
|
-- | Authenticates a request against the user dictionary stored in the
-- context: the value of the @Auth@ header is looked up in that dictionary
-- and the matching 'User' is returned.
--
-- Fails fatally with 'err400' if the request has no @Auth@ header and with
-- 'err401' if the header value is not a key of the dictionary. These are the
-- same errors 'checkAuth' uses; previously both situations were not handled
-- at all and ended in a pattern match failure.
authWithContext :: (HasContextEntry context [(SBS.ByteString, User)]) =>
  Context context -> Request -> IO (RouteResult User)
authWithContext context request = return $
  case lookup "Auth" (requestHeaders request) of
    Nothing -> FailFatal err400
    Just authToken -> case lookup authToken userDict of
      Nothing -> FailFatal err401
      Just user -> Route user
  where
    userDict = getContextEntry context
|
|
|
|
-- * general combinators
|
|
|
|
-- | Combinator built with 'makeCombinator' from 'getCustom'. Handlers below
-- it take the resulting 'String' as an additional argument.
data FooHeader

instance HasServer api context => HasServer (FooHeader :> api) context where
  type ServerT (FooHeader :> api) m = String -> ServerT api m

  -- the context is not needed for this combinator
  route = runCI (makeCombinator (\ _context -> getCustom))
|
|
|
|
-- | Returns the value of the @Foo@ header converted to a 'String', or fails
-- fatally with 'err400' if the request has no such header.
getCustom :: Request -> IO (RouteResult String)
getCustom request =
  return (maybe (FailFatal err400) (Route . cs) fooHeader)
  where
    fooHeader = lookup "Foo" (requestHeaders request)
|
|
|
|
-- * streaming combinators
|
|
|
|
-- | Combinator built with 'makeReqBodyCombinator' from 'getSource'. Handlers
-- below it take a 'Source' as an additional argument.
data StreamRequest

-- | Wrapper around an 'IO' action that yields strict 'SBS.ByteString's.
data Source
  = Source (IO SBS.ByteString)

instance HasServer api context => HasServer (StreamRequest :> api) context where
  type ServerT (StreamRequest :> api) m = Source -> ServerT api m

  -- the context is not needed for this combinator
  route = runCI (makeReqBodyCombinator (\ _context -> getSource))
|
|
|
|
-- | Wraps the given action in a 'Source' without running it.
getSource :: IO SBS.ByteString -> Source
getSource getChunk = Source getChunk
|
|
|
|
-- * utils
|
|
|
|
-- | Runs the given action repeatedly until it yields an empty chunk and
-- returns the concatenation of all chunks read up to that point.
fromBody :: IO SBS.ByteString -> IO SBS.ByteString
fromBody getChunk = getChunk >>= collect
  where
    collect chunk
      -- an empty chunk marks the end of the body
      | chunk == "" = return ""
      | otherwise = (chunk <>) <$> fromBody getChunk
|