Merge remote-tracking branch 'upstream/jkarni/config' into auth

This commit is contained in:
aaron levin 2015-12-27 16:42:28 +01:00
commit 0285ddf707
25 changed files with 22402 additions and 236 deletions

View file

@ -1,4 +1,4 @@
#!/bin/bash - #!/usr/bin/env bash
#=============================================================================== #===============================================================================
# #
# FILE: bump-versions.sh # FILE: bump-versions.sh

View file

@ -1,4 +1,4 @@
#!/bin/bash - #!/usr/bin/env bash
#=============================================================================== #===============================================================================
# #
# FILE: clear-sandbox.sh # FILE: clear-sandbox.sh

View file

@ -1,4 +1,4 @@
#!/bin/bash - #!/usr/bin/env bash
#=============================================================================== #===============================================================================
# #
# FILE: generate-nix-files.sh # FILE: generate-nix-files.sh

View file

@ -1,4 +1,4 @@
#!/bin/bash - #!/usr/bin/env bash
#=============================================================================== #===============================================================================
# #
# FILE: start-sandbox.sh # FILE: start-sandbox.sh

View file

@ -1,4 +1,4 @@
#!/bin/bash - #!/usr/bin/env bash
#=============================================================================== #===============================================================================
# #
# FILE: test-all.sh # FILE: test-all.sh

View file

@ -1,4 +1,4 @@
#!/bin/bash - #!/usr/bin/env bash
#=============================================================================== #===============================================================================
# #
# FILE: upload.sh # FILE: upload.sh

View file

@ -1,3 +1,4 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
@ -16,10 +17,13 @@
-- >>> type EgDefault = Get '[CSV] [(Int, String)] -- >>> type EgDefault = Get '[CSV] [(Int, String)]
module Servant.CSV.Cassava where module Servant.CSV.Cassava where
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative ((<$>))
#endif
import Data.Csv import Data.Csv
import Data.Proxy (Proxy (..)) import Data.Proxy (Proxy (..))
import Data.Typeable (Typeable) import Data.Typeable (Typeable)
import Data.Vector (Vector) import Data.Vector (Vector, toList)
import GHC.Generics (Generic) import GHC.Generics (Generic)
import qualified Network.HTTP.Media as M import qualified Network.HTTP.Media as M
import Servant.API (Accept (..), MimeRender (..), import Servant.API (Accept (..), MimeRender (..),
@ -50,6 +54,18 @@ instance ( DefaultOrdered a, ToNamedRecord a, EncodeOpts opt
mimeRender _ = encodeDefaultOrderedByNameWith (encodeOpts p) mimeRender _ = encodeDefaultOrderedByNameWith (encodeOpts p)
where p = Proxy :: Proxy opt where p = Proxy :: Proxy opt
-- | Encode with 'encodeByNameWith'. The 'Header' param is used for determining
-- the order of headers and fields.
instance ( ToNamedRecord a, EncodeOpts opt
) => MimeRender (CSV', opt) (Header, Vector a) where
mimeRender _ (hdr, vals) = encodeByNameWith (encodeOpts p) hdr (toList vals)
where p = Proxy :: Proxy opt
-- | Encode with 'encodeDefaultOrderedByNameWith'
instance ( DefaultOrdered a, ToNamedRecord a, EncodeOpts opt
) => MimeRender (CSV', opt) (Vector a) where
mimeRender _ = encodeDefaultOrderedByNameWith (encodeOpts p) . toList
where p = Proxy :: Proxy opt
-- ** Encode Options -- ** Encode Options
@ -66,6 +82,17 @@ instance EncodeOpts DefaultEncodeOpts where
-- ** Instances -- ** Instances
-- | Decode with 'decodeByNameWith' -- | Decode with 'decodeByNameWith'
instance ( FromNamedRecord a, DecodeOpts opt
) => MimeUnrender (CSV', opt) (Header, [a]) where
mimeUnrender _ bs = fmap toList <$> decodeByNameWith (decodeOpts p) bs
where p = Proxy :: Proxy opt
-- | Decode with 'decodeWith'. Assumes data has headers, which are stripped.
instance ( FromRecord a, DecodeOpts opt
) => MimeUnrender (CSV', opt) [a] where
mimeUnrender _ bs = toList <$> decodeWith (decodeOpts p) HasHeader bs
where p = Proxy :: Proxy opt
instance ( FromNamedRecord a, DecodeOpts opt instance ( FromNamedRecord a, DecodeOpts opt
) => MimeUnrender (CSV', opt) (Header, Vector a) where ) => MimeUnrender (CSV', opt) (Header, Vector a) where
mimeUnrender _ = decodeByNameWith (decodeOpts p) mimeUnrender _ = decodeByNameWith (decodeOpts p)

View file

@ -75,9 +75,15 @@ instance ToSample Cookie where
instance ToSample SecretData where instance ToSample SecretData where
toSamples _ = singleSample (SecretData "shhhhh!") toSamples _ = singleSample (SecretData "shhhhh!")
instance ToAuthInfo (AuthProtect Cookie User mP mE uP uE) where instance ToAuthInfo (AuthProtect "cookie-auth-lax" Cookie User mP mE uP uE) where
toAuthInfo _ = AuthenticationInfo "In this sentence we outline how authentication works." toAuthInfo _ = AuthenticationInfo "In this sentence we outline how authentication works."
"The following data is required on each request as a serialized header." ("The following data is required on each request as a serialized header."
++ "The API methods will handle authentication failures.")
instance ToAuthInfo (AuthProtect "cookie-auth-strict" Cookie User mP mE uP uE) where
toAuthInfo _ = AuthenticationInfo "In this sentence we outline how authentication works."
("The following data is required on each request as a serialized header."
++ "The handlers will handle authentication failures.")
-- We define some introductory sections, these will appear at the top of the -- We define some introductory sections, these will appear at the top of the
-- documentation. -- documentation.
@ -108,9 +114,9 @@ type TestApi =
:<|> "greet" :> Capture "greetid" Text :> Delete '[JSON] () :<|> "greet" :> Capture "greetid" Text :> Delete '[JSON] ()
-- GET /private -- GET /private
:<|> "private" :> AuthProtect Cookie User 'Strict () 'Strict () :> Get '[JSON] SecretData :<|> "private" :> AuthProtect "cookie-auth-strict" Cookie User 'Strict () 'Strict () :> Get '[JSON] SecretData
-- GET /private-lax -- GET /private-lax
:<|> "private-lax" :> AuthProtect Cookie User 'Lax () 'Lax () :> Get '[JSON] SecretData :<|> "private-lax" :> AuthProtect "cookie-auth-lax" Cookie User 'Lax () 'Lax () :> Get '[JSON] SecretData
testApi :: Proxy TestApi testApi :: Proxy TestApi
testApi = Proxy testApi = Proxy

View file

@ -727,14 +727,14 @@ instance
( HasDocs sublayout ( HasDocs sublayout
, ToSample auth , ToSample auth
, ToSample usr , ToSample usr
, ToAuthInfo (AuthProtect auth usr mPolicy mError uPolicy uError) , ToAuthInfo (AuthProtect tag auth usr mPolicy mError uPolicy uError)
) )
=> HasDocs (AuthProtect auth usr mPolicy mError uPolicy uError :> sublayout) where => HasDocs (AuthProtect tag auth usr mPolicy mError uPolicy uError :> sublayout) where
docsFor Proxy (endpoint, action) = docsFor Proxy (endpoint, action) =
docsFor (Proxy :: Proxy sublayout) (endpoint, action') docsFor (Proxy :: Proxy sublayout) (endpoint, action')
where where
authProxy = Proxy :: Proxy (AuthProtect auth usr mPolicy mError uPolicy uError) authProxy = Proxy :: Proxy (AuthProtect tag auth usr mPolicy mError uPolicy uError)
action' = over authInfo (|> toAuthInfo authProxy) action action' = over authInfo (|> toAuthInfo authProxy) action
instance instance

21836
servant-server/codex.tags Normal file

File diff suppressed because it is too large Load diff

View file

@ -59,7 +59,7 @@ server = helloH :<|> postGreetH :<|> deleteGreetH
-- Turn the server into a WAI app. 'serve' is provided by servant, -- Turn the server into a WAI app. 'serve' is provided by servant,
-- more precisely by the Servant.Server module. -- more precisely by the Servant.Server module.
test :: Application test :: Application
test = serve testApi server test = serve testApi EmptyConfig server
-- Run the server. -- Run the server.
-- --

View file

@ -37,11 +37,13 @@ library
Servant.Server Servant.Server
Servant.Server.Internal Servant.Server.Internal
Servant.Server.Internal.Authentication Servant.Server.Internal.Authentication
Servant.Server.Internal.Config
Servant.Server.Internal.Enter Servant.Server.Internal.Enter
Servant.Server.Internal.Router Servant.Server.Internal.Router
Servant.Server.Internal.RoutingApplication Servant.Server.Internal.RoutingApplication
Servant.Server.Internal.ServantErr Servant.Server.Internal.ServantErr
Servant.Utils.StaticFiles Servant.Utils.StaticFiles
build-depends: build-depends:
base >= 4.7 && < 5 base >= 4.7 && < 5
, aeson >= 0.7 && < 0.11 , aeson >= 0.7 && < 0.11
@ -49,6 +51,7 @@ library
, base64-bytestring >= 1.0.0.0 , base64-bytestring >= 1.0.0.0
, bytestring >= 0.10 && < 0.11 , bytestring >= 0.10 && < 0.11
, containers >= 0.5 && < 0.6 , containers >= 0.5 && < 0.6
, deepseq == 1.4.1.1
, http-api-data >= 0.1 && < 0.3 , http-api-data >= 0.1 && < 0.3
, http-types >= 0.8 && < 0.10 , http-types >= 0.8 && < 0.10
, network-uri >= 2.6 && < 2.7 , network-uri >= 2.6 && < 2.7
@ -69,6 +72,7 @@ library
, warp >= 3.0 && < 3.2 , warp >= 3.0 && < 3.2
, word8 >= 0.1.0 && < 0.1.3 , word8 >= 0.1.0 && < 0.1.3
, jwt , jwt
hs-source-dirs: src hs-source-dirs: src
default-language: Haskell2010 default-language: Haskell2010
ghc-options: -Wall ghc-options: -Wall
@ -96,6 +100,7 @@ test-suite spec
main-is: Spec.hs main-is: Spec.hs
other-modules: other-modules:
Servant.Server.Internal.EnterSpec Servant.Server.Internal.EnterSpec
Servant.Server.Internal.ConfigSpec
Servant.ServerSpec Servant.ServerSpec
Servant.Utils.StaticFilesSpec Servant.Utils.StaticFilesSpec
Servant.Server.ErrorSpec Servant.Server.ErrorSpec
@ -116,6 +121,7 @@ test-suite spec
, servant , servant
, servant-server , servant-server
, string-conversions , string-conversions
, should-not-typecheck == 2.*
, temporary , temporary
, text , text
, transformers , transformers

View file

@ -35,6 +35,11 @@ module Servant.Server
, generalizeNat , generalizeNat
, tweakResponse , tweakResponse
-- * Config
, ConfigEntry(..)
, Config(..)
, (.:)
-- * Default error type -- * Default error type
, ServantErr(..) , ServantErr(..)
-- ** 3XX -- ** 3XX
@ -98,14 +103,17 @@ import Servant.Server.Internal.Authentication
-- > myApi :: Proxy MyApi -- > myApi :: Proxy MyApi
-- > myApi = Proxy -- > myApi = Proxy
-- > -- >
-- > cfg :: Config '[]
-- > cfg = EmptyConfig
-- >
-- > app :: Application -- > app :: Application
-- > app = serve myApi server -- > app = serve myApi cfg server
-- > -- >
-- > main :: IO () -- > main :: IO ()
-- > main = Network.Wai.Handler.Warp.run 8080 app -- > main = Network.Wai.Handler.Warp.run 8080 app
-- --
serve :: HasServer layout => Proxy layout -> Server layout -> Application serve :: (HasServer layout, HasCfg layout a) => Proxy layout -> Config a -> Server layout -> Application
serve p server = toApplication (runRouter (route p d)) serve p cfg server = toApplication (runRouter (route p cfg d))
where where
d = Delayed r r r r (\ _ _ _ -> Route server) d = Delayed r r r r (\ _ _ _ -> Route server)
r = return (Route ()) r = return (Route ())

View file

@ -1,4 +1,5 @@
{-# LANGUAGE CPP #-} {-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-} {-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
@ -9,6 +10,7 @@
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
#if !MIN_VERSION_base(4,8,0) #if !MIN_VERSION_base(4,8,0)
{-# LANGUAGE OverlappingInstances #-} {-# LANGUAGE OverlappingInstances #-}
#endif #endif
@ -16,6 +18,7 @@
module Servant.Server.Internal module Servant.Server.Internal
( module Servant.Server.Internal ( module Servant.Server.Internal
, module Servant.Server.Internal.Authentication , module Servant.Server.Internal.Authentication
, module Servant.Server.Internal.Config
, module Servant.Server.Internal.Router , module Servant.Server.Internal.Router
, module Servant.Server.Internal.RoutingApplication , module Servant.Server.Internal.RoutingApplication
, module Servant.Server.Internal.ServantErr , module Servant.Server.Internal.ServantErr
@ -35,6 +38,7 @@ import Data.Text (Text)
import Data.Typeable import Data.Typeable
import GHC.TypeLits (KnownSymbol, import GHC.TypeLits (KnownSymbol,
symbolVal) symbolVal)
import GHC.Exts (Constraint)
import Network.HTTP.Types hiding (Header, import Network.HTTP.Types hiding (Header,
ResponseHeaders) ResponseHeaders)
import Network.Socket (SockAddr) import Network.Socket (SockAddr)
@ -61,7 +65,9 @@ import Servant.API ((:<|>) (..), (:>),
ReqBody, Vault) ReqBody, Vault)
import Servant.API.Authentication (AuthPolicy (Strict, Lax), import Servant.API.Authentication (AuthPolicy (Strict, Lax),
AuthProtect, AuthProtect,
AuthProtected, AuthProtectSimple,
AuthProtected(..),
AuthProtectedSimple(..),
SAuthPolicy(SLax,SStrict)) SAuthPolicy(SLax,SStrict))
import Servant.API.ContentTypes (AcceptHeader (..), import Servant.API.ContentTypes (AcceptHeader (..),
AllCTRender (..), AllCTRender (..),
@ -72,7 +78,8 @@ import Servant.API.ResponseHeaders (GetHeaders,
Headers, Headers,
getHeaders, getHeaders,
getResponse) getResponse)
import Servant.Server.Internal.Authentication (AuthData (authData)) import Servant.Server.Internal.Authentication (addAuthCheck, AuthData (authData))
import Servant.Server.Internal.Config
import Servant.Server.Internal.Router import Servant.Server.Internal.Router
import Servant.Server.Internal.RoutingApplication import Servant.Server.Internal.RoutingApplication
import Servant.Server.Internal.ServantErr import Servant.Server.Internal.ServantErr
@ -82,11 +89,13 @@ import Web.HttpApiData.Internal (parseUrlPieceMaybe, parseHeaderMaybe,
class HasServer layout where class HasServer layout where
type ServerT layout (m :: * -> *) :: * type ServerT layout (m :: * -> *) :: *
type HasCfg layout (x :: [*]) :: Constraint
route :: Proxy layout -> Delayed (Server layout) -> Router route :: HasCfg layout a => Proxy layout -> Config a -> Delayed (Server layout) -> Router
type Server layout = ServerT layout (ExceptT ServantErr IO) type Server layout = ServerT layout (ExceptT ServantErr IO)
-- * Instances -- * Instances
-- | A server for @a ':<|>' b@ first tries to match the request against the route -- | A server for @a ':<|>' b@ first tries to match the request against the route
@ -103,9 +112,10 @@ type Server layout = ServerT layout (ExceptT ServantErr IO)
instance (HasServer a, HasServer b) => HasServer (a :<|> b) where instance (HasServer a, HasServer b) => HasServer (a :<|> b) where
type ServerT (a :<|> b) m = ServerT a m :<|> ServerT b m type ServerT (a :<|> b) m = ServerT a m :<|> ServerT b m
type HasCfg (a :<|> b) x = (HasCfg a x, HasCfg b x)
route Proxy server = choice (route pa ((\ (a :<|> _) -> a) <$> server)) route Proxy cfg server = choice (route pa cfg ((\ (a :<|> _) -> a) <$> server))
(route pb ((\ (_ :<|> b) -> b) <$> server)) (route pb cfg ((\ (_ :<|> b) -> b) <$> server))
where pa = Proxy :: Proxy a where pa = Proxy :: Proxy a
pb = Proxy :: Proxy b pb = Proxy :: Proxy b
@ -135,9 +145,12 @@ instance (KnownSymbol capture, FromHttpApiData a, HasServer sublayout)
type ServerT (Capture capture a :> sublayout) m = type ServerT (Capture capture a :> sublayout) m =
a -> ServerT sublayout m a -> ServerT sublayout m
route Proxy d = type HasCfg (Capture capture a :> sublayout) x = HasCfg sublayout x
route Proxy cfg d =
DynamicRouter $ \ first -> DynamicRouter $ \ first ->
route (Proxy :: Proxy sublayout) route (Proxy :: Proxy sublayout)
cfg
(addCapture d $ case captured captureProxy first of (addCapture d $ case captured captureProxy first of
Nothing -> return $ Fail err404 Nothing -> return $ Fail err404
Just v -> return $ Route v Just v -> return $ Route v
@ -236,7 +249,9 @@ instance
type ServerT (Delete ctypes a) m = m a type ServerT (Delete ctypes a) m = m a
route Proxy = methodRouter methodDelete (Proxy :: Proxy ctypes) ok200 type HasCfg (Delete ctypes a) x = ()
route Proxy _ = methodRouter methodDelete (Proxy :: Proxy ctypes) ok200
instance instance
#if MIN_VERSION_base(4,8,0) #if MIN_VERSION_base(4,8,0)
@ -246,7 +261,9 @@ instance
type ServerT (Delete ctypes ()) m = m () type ServerT (Delete ctypes ()) m = m ()
route Proxy = methodRouterEmpty methodDelete type HasCfg (Delete ctypes ()) x = ()
route Proxy _ = methodRouterEmpty methodDelete
-- Add response headers -- Add response headers
instance instance
@ -258,63 +275,137 @@ instance
type ServerT (Delete ctypes (Headers h v)) m = m (Headers h v) type ServerT (Delete ctypes (Headers h v)) m = m (Headers h v)
route Proxy = methodRouterHeaders methodDelete (Proxy :: Proxy ctypes) ok200 type HasCfg (Delete ctypes (Headers h v)) x = ()
route Proxy _ = methodRouterHeaders methodDelete (Proxy :: Proxy ctypes) ok200
-- | Simple Authentication instance
instance
#if MIN_VERSION_base(4,8,0)
{-# OVERLAPPABLE #-}
#endif
( HasServer sublayout
) => HasServer (AuthProtectSimple tag usr :> sublayout) where
type ServerT (AuthProtectSimple tag usr :> sublayout) m =
usr -> ServerT sublayout m
type HasCfg (AuthProtectSimple tag usr :> sublayout) x =
( HasConfigEntry x tag (AuthProtectedSimple Request ServantErr usr)
, HasCfg sublayout x
)
route _ cfg subserver =
let authProtection :: AuthProtectedSimple Request ServantErr usr
authProtection = getConfigEntry (Proxy :: Proxy tag) cfg
handler = fmap (either FailFatal Route) . authHandler authProtection
in WithRequest $ \ request ->
route (Proxy :: Proxy sublayout)
cfg
(addAuth subserver (handler request))
-- | Authentication in Missing x Unauth = Strict x Strict mode -- | Authentication in Missing x Unauth = Strict x Strict mode
instance instance
#if MIN_VERSION_base(4,8,0) #if MIN_VERSION_base(4,8,0)
{-# OVERLAPPABLE #-} {-# OVERLAPPABLE #-}
#endif #endif
(AuthData authData mError, HasServer sublayout) => HasServer (AuthProtect authData (usr :: *) 'Strict (mError :: *) 'Strict (uError :: *) :> sublayout) where ( AuthData authData mError
type ServerT (AuthProtect authData usr 'Strict mError 'Strict uError :> sublayout) m = , HasServer sublayout
AuthProtected IO ServantErr 'Strict mError 'Strict uError authData usr (usr -> ServerT sublayout m) ) => HasServer (AuthProtect tag authData (usr :: *) 'Strict (mError :: *) 'Strict (uError :: *) :> sublayout) where
route _ subserver = WithRequest $ \ request -> type ServerT (AuthProtect tag authData usr 'Strict mError 'Strict uError :> sublayout) m =
route (Proxy :: Proxy sublayout) (addAuthCheck SStrict SStrict subserver (authCheck request)) usr -> ServerT sublayout m
where
authCheck req = pure . Route $ authData req type HasCfg (AuthProtect tag authData usr 'Strict mError 'Strict uError :> sublayout) a =
( HasConfigEntry a tag (AuthProtected IO ServantErr 'Strict mError 'Strict uError authData usr)
, HasCfg sublayout a
)
route _ cfg subserver =
let authProtection :: (AuthProtected IO ServantErr 'Strict mError 'Strict uError authData usr )
authProtection = getConfigEntry (Proxy :: Proxy tag) cfg
extractAuthData req = pure . Route $ authData req
in WithRequest $ \ request ->
route (Proxy :: Proxy sublayout)
cfg
(addAuthCheck SStrict SStrict authProtection subserver (extractAuthData request))
-- | Authentication in Missing x Unauth = Strict x Lax mode -- | Authentication in Missing x Unauth = Strict x Lax mode
instance instance
#if MIN_VERSION_base(4,8,0) #if MIN_VERSION_base(4,8,0)
{-# OVERLAPPABLE #-} {-# OVERLAPPABLE #-}
#endif #endif
(AuthData authData mError, HasServer sublayout) => HasServer (AuthProtect authData (usr :: *) 'Strict (mError :: *) 'Lax (uError :: *) :> sublayout) where ( AuthData authData mError
type ServerT (AuthProtect authData usr 'Strict mError 'Lax uError :> sublayout) m = , HasServer sublayout
AuthProtected IO ServantErr 'Strict mError 'Lax uError authData usr (Either uError usr -> ServerT sublayout m) ) => HasServer (AuthProtect tag authData (usr :: *) 'Strict (mError :: *) 'Lax (uError :: *) :> sublayout) where
route _ subserver = WithRequest $ \ request -> type ServerT (AuthProtect tag authData usr 'Strict mError 'Lax uError :> sublayout) m =
route (Proxy :: Proxy sublayout) (addAuthCheck SStrict SLax subserver (authCheck request)) Either uError usr -> ServerT sublayout m
where
authCheck req = pure . Route $ authData req type HasCfg (AuthProtect tag authData usr 'Strict mError 'Lax uError :> sublayout) a =
( HasConfigEntry a tag (AuthProtected IO ServantErr 'Strict mError 'Lax uError authData usr)
, HasCfg sublayout a
)
route _ cfg subserver =
let authProtection :: (AuthProtected IO ServantErr 'Strict mError 'Lax uError authData usr )
authProtection = getConfigEntry (Proxy :: Proxy tag) cfg
extractAuthData req = pure . Route $ authData req
in WithRequest $ \ request ->
route (Proxy :: Proxy sublayout)
cfg
(addAuthCheck SStrict SLax authProtection subserver (extractAuthData request))
-- | Authentication in Missing x Unauth = Lax x Strict mode -- | Authentication in Missing x Unauth = Lax x Strict mode
instance instance
#if MIN_VERSION_base(4,8,0) #if MIN_VERSION_base(4,8,0)
{-# OVERLAPPABLE #-} {-# OVERLAPPABLE #-}
#endif #endif
(AuthData authData mError, HasServer sublayout) => HasServer (AuthProtect authData (usr :: *) 'Lax (mError :: *) 'Strict (uError :: *) :> sublayout) where ( AuthData authData mError
type ServerT (AuthProtect authData usr 'Lax mError 'Strict uError :> sublayout) m = , HasServer sublayout
AuthProtected IO ServantErr 'Lax mError 'Strict uError authData usr (Either mError usr -> ServerT sublayout m) ) => HasServer (AuthProtect tag authData (usr :: *) 'Lax (mError :: *) 'Strict (uError :: *) :> sublayout) where
route _ subserver = WithRequest $ \ request -> type ServerT (AuthProtect tag authData usr 'Lax mError 'Strict uError :> sublayout) m =
route (Proxy :: Proxy sublayout) (addAuthCheck SLax SStrict subserver (authCheck request)) Either mError usr -> ServerT sublayout m
where
authCheck req = pure . Route $ authData req type HasCfg (AuthProtect tag authData usr 'Lax mError 'Strict uError :> sublayout) a =
( HasConfigEntry a tag (AuthProtected IO ServantErr 'Lax mError 'Strict uError authData usr)
, HasCfg sublayout a
)
route _ cfg subserver =
let authProtection :: (AuthProtected IO ServantErr 'Lax mError 'Strict uError authData usr )
authProtection = getConfigEntry (Proxy :: Proxy tag) cfg
extractAuthData req = pure . Route $ authData req
in WithRequest $ \ request ->
route (Proxy :: Proxy sublayout)
cfg
(addAuthCheck SLax SStrict authProtection subserver (extractAuthData request))
-- | Authentication in Missing x Unauth = Lax x Lax mode -- | Authentication in Missing x Unauth = Lax x Lax mode
instance instance
#if MIN_VERSION_base(4,8,0) #if MIN_VERSION_base(4,8,0)
{-# OVERLAPPABLE #-} {-# OVERLAPPABLE #-}
#endif #endif
(AuthData authData mError, HasServer sublayout) => HasServer (AuthProtect authData (usr :: *) 'Lax (mError :: *) 'Lax (uError :: *) :> sublayout) where ( AuthData authData mError
type ServerT (AuthProtect authData usr 'Lax mError 'Lax uError :> sublayout) m = , HasServer sublayout
AuthProtected IO ServantErr 'Lax mError 'Lax uError authData usr (Either (Either mError uError) usr -> ServerT sublayout m) ) => HasServer (AuthProtect tag authData (usr :: *) 'Lax (mError :: *) 'Lax (uError :: *) :> sublayout) where
type ServerT (AuthProtect tag authData usr 'Lax mError 'Lax uError :> sublayout) m =
Either (Either mError uError) usr -> ServerT sublayout m
route _ subserver = WithRequest $ \ request -> type HasCfg (AuthProtect tag authData usr 'Lax mError 'Lax uError :> sublayout) a =
route (Proxy :: Proxy sublayout) (addAuthCheck SLax SLax subserver (authCheck request)) ( HasConfigEntry a tag (AuthProtected IO ServantErr 'Lax mError 'Lax uError authData usr)
where , HasCfg sublayout a
authCheck req = pure . Route $ authData req )
route _ cfg subserver =
let authProtection :: (AuthProtected IO ServantErr 'Lax mError 'Lax uError authData usr )
authProtection = getConfigEntry (Proxy :: Proxy tag) cfg
extractAuthData req = pure . Route $ authData req
in WithRequest $ \ request ->
route (Proxy :: Proxy sublayout)
cfg
(addAuthCheck SLax SLax authProtection subserver (extractAuthData request))
-- | When implementing the handler for a 'Get' endpoint, -- | When implementing the handler for a 'Get' endpoint,
-- just like for 'Servant.API.Delete.Delete', 'Servant.API.Post.Post' -- just like for 'Servant.API.Delete.Delete', 'Servant.API.Post.Post'
@ -337,7 +428,9 @@ instance
type ServerT (Get ctypes a) m = m a type ServerT (Get ctypes a) m = m a
route Proxy = methodRouter methodGet (Proxy :: Proxy ctypes) ok200 type HasCfg (Get ctypes a) x = ()
route Proxy _ = methodRouter methodGet (Proxy :: Proxy ctypes) ok200
-- '()' ==> 204 No Content -- '()' ==> 204 No Content
instance instance
@ -348,7 +441,9 @@ instance
type ServerT (Get ctypes ()) m = m () type ServerT (Get ctypes ()) m = m ()
route Proxy = methodRouterEmpty methodGet type HasCfg (Get ctypes ()) a = ()
route Proxy _ = methodRouterEmpty methodGet
-- Add response headers -- Add response headers
instance instance
@ -360,7 +455,9 @@ instance
type ServerT (Get ctypes (Headers h v)) m = m (Headers h v) type ServerT (Get ctypes (Headers h v)) m = m (Headers h v)
route Proxy = methodRouterHeaders methodGet (Proxy :: Proxy ctypes) ok200 type HasCfg (Get ctypes (Headers h v)) a = ()
route Proxy _ = methodRouterHeaders methodGet (Proxy :: Proxy ctypes) ok200
-- | If you use 'Header' in one of the endpoints for your API, -- | If you use 'Header' in one of the endpoints for your API,
-- this automatically requires your server-side handler to be a function -- this automatically requires your server-side handler to be a function
@ -388,9 +485,11 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout)
type ServerT (Header sym a :> sublayout) m = type ServerT (Header sym a :> sublayout) m =
Maybe a -> ServerT sublayout m Maybe a -> ServerT sublayout m
route Proxy subserver = WithRequest $ \ request -> type HasCfg (Header sym a :> sublayout) x = HasCfg sublayout x
route Proxy cfg subserver = WithRequest $ \ request ->
let mheader = parseHeaderMaybe =<< lookup str (requestHeaders request) let mheader = parseHeaderMaybe =<< lookup str (requestHeaders request)
in route (Proxy :: Proxy sublayout) (passToServer subserver mheader) in route (Proxy :: Proxy sublayout) cfg (passToServer subserver mheader)
where str = fromString $ symbolVal (Proxy :: Proxy sym) where str = fromString $ symbolVal (Proxy :: Proxy sym)
-- | When implementing the handler for a 'Post' endpoint, -- | When implementing the handler for a 'Post' endpoint,
@ -415,7 +514,9 @@ instance
type ServerT (Post ctypes a) m = m a type ServerT (Post ctypes a) m = m a
route Proxy = methodRouter methodPost (Proxy :: Proxy ctypes) created201 type HasCfg (Post ctypes a) x = ()
route Proxy _ = methodRouter methodPost (Proxy :: Proxy ctypes) created201
instance instance
#if MIN_VERSION_base(4,8,0) #if MIN_VERSION_base(4,8,0)
@ -425,7 +526,9 @@ instance
type ServerT (Post ctypes ()) m = m () type ServerT (Post ctypes ()) m = m ()
route Proxy = methodRouterEmpty methodPost type HasCfg (Post ctypes ()) x = ()
route Proxy _ = methodRouterEmpty methodPost
-- Add response headers -- Add response headers
instance instance
@ -437,7 +540,9 @@ instance
type ServerT (Post ctypes (Headers h v)) m = m (Headers h v) type ServerT (Post ctypes (Headers h v)) m = m (Headers h v)
route Proxy = methodRouterHeaders methodPost (Proxy :: Proxy ctypes) created201 type HasCfg (Post ctypes (Headers h v)) x = ()
route Proxy _ = methodRouterHeaders methodPost (Proxy :: Proxy ctypes) created201
-- | When implementing the handler for a 'Put' endpoint, -- | When implementing the handler for a 'Put' endpoint,
-- just like for 'Servant.API.Delete.Delete', 'Servant.API.Get.Get' -- just like for 'Servant.API.Delete.Delete', 'Servant.API.Get.Get'
@ -460,7 +565,9 @@ instance
type ServerT (Put ctypes a) m = m a type ServerT (Put ctypes a) m = m a
route Proxy = methodRouter methodPut (Proxy :: Proxy ctypes) ok200 type HasCfg (Put ctypes a) x = ()
route Proxy _ = methodRouter methodPut (Proxy :: Proxy ctypes) ok200
instance instance
#if MIN_VERSION_base(4,8,0) #if MIN_VERSION_base(4,8,0)
@ -470,7 +577,9 @@ instance
type ServerT (Put ctypes ()) m = m () type ServerT (Put ctypes ()) m = m ()
route Proxy = methodRouterEmpty methodPut type HasCfg (Put ctypes ()) x = ()
route Proxy _ = methodRouterEmpty methodPut
-- Add response headers -- Add response headers
instance instance
@ -482,7 +591,9 @@ instance
type ServerT (Put ctypes (Headers h v)) m = m (Headers h v) type ServerT (Put ctypes (Headers h v)) m = m (Headers h v)
route Proxy = methodRouterHeaders methodPut (Proxy :: Proxy ctypes) ok200 type HasCfg (Put ctypes (Headers h v)) x = ()
route Proxy _ = methodRouterHeaders methodPut (Proxy :: Proxy ctypes) ok200
-- | When implementing the handler for a 'Patch' endpoint, -- | When implementing the handler for a 'Patch' endpoint,
-- just like for 'Servant.API.Delete.Delete', 'Servant.API.Get.Get' -- just like for 'Servant.API.Delete.Delete', 'Servant.API.Get.Get'
@ -503,7 +614,9 @@ instance
type ServerT (Patch ctypes a) m = m a type ServerT (Patch ctypes a) m = m a
route Proxy = methodRouter methodPatch (Proxy :: Proxy ctypes) ok200 type HasCfg (Patch ctypes a) x = ()
route Proxy _ = methodRouter methodPatch (Proxy :: Proxy ctypes) ok200
instance instance
#if MIN_VERSION_base(4,8,0) #if MIN_VERSION_base(4,8,0)
@ -513,7 +626,9 @@ instance
type ServerT (Patch ctypes ()) m = m () type ServerT (Patch ctypes ()) m = m ()
route Proxy = methodRouterEmpty methodPatch type HasCfg (Patch ctypes ()) x = ()
route Proxy _ = methodRouterEmpty methodPatch
-- Add response headers -- Add response headers
instance instance
@ -525,7 +640,9 @@ instance
type ServerT (Patch ctypes (Headers h v)) m = m (Headers h v) type ServerT (Patch ctypes (Headers h v)) m = m (Headers h v)
route Proxy = methodRouterHeaders methodPatch (Proxy :: Proxy ctypes) ok200 type HasCfg (Patch ctypes (Headers h v)) x = ()
route Proxy _ = methodRouterHeaders methodPatch (Proxy :: Proxy ctypes) ok200
-- | If you use @'QueryParam' "author" Text@ in one of the endpoints for your API, -- | If you use @'QueryParam' "author" Text@ in one of the endpoints for your API,
-- this automatically requires your server-side handler to be a function -- this automatically requires your server-side handler to be a function
@ -554,7 +671,9 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout)
type ServerT (QueryParam sym a :> sublayout) m = type ServerT (QueryParam sym a :> sublayout) m =
Maybe a -> ServerT sublayout m Maybe a -> ServerT sublayout m
route Proxy subserver = WithRequest $ \ request -> type HasCfg (QueryParam sym a :> sublayout) x = HasCfg sublayout x
route Proxy cfg subserver = WithRequest $ \ request ->
let querytext = parseQueryText $ rawQueryString request let querytext = parseQueryText $ rawQueryString request
param = param =
case lookup paramname querytext of case lookup paramname querytext of
@ -562,7 +681,7 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout)
Just Nothing -> Nothing -- param present with no value -> Nothing Just Nothing -> Nothing -- param present with no value -> Nothing
Just (Just v) -> parseQueryParamMaybe v -- if present, we try to convert to Just (Just v) -> parseQueryParamMaybe v -- if present, we try to convert to
-- the right type -- the right type
in route (Proxy :: Proxy sublayout) (passToServer subserver param) in route (Proxy :: Proxy sublayout) cfg (passToServer subserver param)
where paramname = cs $ symbolVal (Proxy :: Proxy sym) where paramname = cs $ symbolVal (Proxy :: Proxy sym)
-- | If you use @'QueryParams' "authors" Text@ in one of the endpoints for your API, -- | If you use @'QueryParams' "authors" Text@ in one of the endpoints for your API,
@ -590,14 +709,17 @@ instance (KnownSymbol sym, FromHttpApiData a, HasServer sublayout)
type ServerT (QueryParams sym a :> sublayout) m = type ServerT (QueryParams sym a :> sublayout) m =
[a] -> ServerT sublayout m [a] -> ServerT sublayout m
route Proxy subserver = WithRequest $ \ request -> type HasCfg (QueryParams sym a :> sublayout) x =
HasCfg sublayout x
route Proxy cfg subserver = WithRequest $ \ request ->
let querytext = parseQueryText $ rawQueryString request let querytext = parseQueryText $ rawQueryString request
-- if sym is "foo", we look for query string parameters -- if sym is "foo", we look for query string parameters
-- named "foo" or "foo[]" and call parseQueryParam on the -- named "foo" or "foo[]" and call parseQueryParam on the
-- corresponding values -- corresponding values
parameters = filter looksLikeParam querytext parameters = filter looksLikeParam querytext
values = mapMaybe (convert . snd) parameters values = mapMaybe (convert . snd) parameters
in route (Proxy :: Proxy sublayout) (passToServer subserver values) in route (Proxy :: Proxy sublayout) cfg (passToServer subserver values)
where paramname = cs $ symbolVal (Proxy :: Proxy sym) where paramname = cs $ symbolVal (Proxy :: Proxy sym)
looksLikeParam (name, _) = name == paramname || name == (paramname <> "[]") looksLikeParam (name, _) = name == paramname || name == (paramname <> "[]")
convert Nothing = Nothing convert Nothing = Nothing
@ -621,13 +743,16 @@ instance (KnownSymbol sym, HasServer sublayout)
type ServerT (QueryFlag sym :> sublayout) m = type ServerT (QueryFlag sym :> sublayout) m =
Bool -> ServerT sublayout m Bool -> ServerT sublayout m
route Proxy subserver = WithRequest $ \ request -> type HasCfg (QueryFlag sym :> sublayout) a =
HasCfg sublayout a
route Proxy cfg subserver = WithRequest $ \ request ->
let querytext = parseQueryText $ rawQueryString request let querytext = parseQueryText $ rawQueryString request
param = case lookup paramname querytext of param = case lookup paramname querytext of
Just Nothing -> True -- param is there, with no value Just Nothing -> True -- param is there, with no value
Just (Just v) -> examine v -- param with a value Just (Just v) -> examine v -- param with a value
Nothing -> False -- param not in the query string Nothing -> False -- param not in the query string
in route (Proxy :: Proxy sublayout) (passToServer subserver param) in route (Proxy :: Proxy sublayout) cfg (passToServer subserver param)
where paramname = cs $ symbolVal (Proxy :: Proxy sym) where paramname = cs $ symbolVal (Proxy :: Proxy sym)
examine v | v == "true" || v == "1" || v == "" = True examine v | v == "true" || v == "1" || v == "" = True
| otherwise = False | otherwise = False
@ -644,7 +769,9 @@ instance HasServer Raw where
type ServerT Raw m = Application type ServerT Raw m = Application
route Proxy rawApplication = LeafRouter $ \ request respond -> do type HasCfg Raw x = ()
route Proxy _ rawApplication = LeafRouter $ \ request respond -> do
r <- runDelayed rawApplication r <- runDelayed rawApplication
case r of case r of
Route app -> app request (respond . Route) Route app -> app request (respond . Route)
@ -678,8 +805,11 @@ instance ( AllCTUnrender list a, HasServer sublayout
type ServerT (ReqBody list a :> sublayout) m = type ServerT (ReqBody list a :> sublayout) m =
a -> ServerT sublayout m a -> ServerT sublayout m
route Proxy subserver = WithRequest $ \ request -> type HasCfg (ReqBody list a :> sublayout) x =
route (Proxy :: Proxy sublayout) (addBodyCheck subserver (bodyCheck request)) HasCfg sublayout x
route Proxy cfg subserver = WithRequest $ \ request ->
route (Proxy :: Proxy sublayout) cfg (addBodyCheck subserver (bodyCheck request))
where where
bodyCheck request = do bodyCheck request = do
-- See HTTP RFC 2616, section 7.2.1 -- See HTTP RFC 2616, section 7.2.1
@ -701,36 +831,46 @@ instance (KnownSymbol path, HasServer sublayout) => HasServer (path :> sublayout
type ServerT (path :> sublayout) m = ServerT sublayout m type ServerT (path :> sublayout) m = ServerT sublayout m
route Proxy subserver = StaticRouter $ type HasCfg (path :> sublayout) x = HasCfg sublayout x
route Proxy cfg subserver = StaticRouter $
M.singleton (cs (symbolVal proxyPath)) M.singleton (cs (symbolVal proxyPath))
(route (Proxy :: Proxy sublayout) subserver) (route (Proxy :: Proxy sublayout) cfg subserver)
where proxyPath = Proxy :: Proxy path where proxyPath = Proxy :: Proxy path
instance HasServer api => HasServer (RemoteHost :> api) where instance HasServer api => HasServer (RemoteHost :> api) where
type ServerT (RemoteHost :> api) m = SockAddr -> ServerT api m type ServerT (RemoteHost :> api) m = SockAddr -> ServerT api m
route Proxy subserver = WithRequest $ \req -> type HasCfg (RemoteHost :> api) a = HasCfg api a
route (Proxy :: Proxy api) (passToServer subserver $ remoteHost req)
route Proxy cfg subserver = WithRequest $ \req ->
route (Proxy :: Proxy api) cfg (passToServer subserver $ remoteHost req)
instance HasServer api => HasServer (IsSecure :> api) where instance HasServer api => HasServer (IsSecure :> api) where
type ServerT (IsSecure :> api) m = IsSecure -> ServerT api m type ServerT (IsSecure :> api) m = IsSecure -> ServerT api m
route Proxy subserver = WithRequest $ \req -> type HasCfg (IsSecure :> api) a = HasCfg api a
route (Proxy :: Proxy api) (passToServer subserver $ secure req)
route Proxy cfg subserver = WithRequest $ \req ->
route (Proxy :: Proxy api) cfg (passToServer subserver $ secure req)
where secure req = if isSecure req then Secure else NotSecure where secure req = if isSecure req then Secure else NotSecure
instance HasServer api => HasServer (Vault :> api) where instance HasServer api => HasServer (Vault :> api) where
type ServerT (Vault :> api) m = Vault -> ServerT api m type ServerT (Vault :> api) m = Vault -> ServerT api m
route Proxy subserver = WithRequest $ \req -> type HasCfg (Vault :> api) a = HasCfg api a
route (Proxy :: Proxy api) (passToServer subserver $ vault req)
route Proxy cfg subserver = WithRequest $ \req ->
route (Proxy :: Proxy api) cfg (passToServer subserver $ vault req)
instance HasServer api => HasServer (HttpVersion :> api) where instance HasServer api => HasServer (HttpVersion :> api) where
type ServerT (HttpVersion :> api) m = HttpVersion -> ServerT api m type ServerT (HttpVersion :> api) m = HttpVersion -> ServerT api m
route Proxy subserver = WithRequest $ \req -> type HasCfg (HttpVersion :> api) a = HasCfg api a
route (Proxy :: Proxy api) (passToServer subserver $ httpVersion req)
route Proxy cfg subserver = WithRequest $ \req ->
route (Proxy :: Proxy api) cfg (passToServer subserver $ httpVersion req)
pathIsEmpty :: Request -> Bool pathIsEmpty :: Request -> Bool
pathIsEmpty = go . pathInfo pathIsEmpty = go . pathInfo

View file

@ -8,8 +8,14 @@
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
module Servant.Server.Internal.Authentication module Servant.Server.Internal.Authentication
( AuthData (..) ( addAuthCheck
, addAuthCheckSS
, addAuthCheckSL
, addAuthCheckLS
, addAuthCheckLL
, AuthData (..)
, authProtect , authProtect
, authProtectSimple
, basicAuthLax , basicAuthLax
, basicAuthStrict , basicAuthStrict
, jwtAuthStrict , jwtAuthStrict
@ -33,10 +39,13 @@ import Network.Wai (Request, requestHeaders)
import Servant.Server.Internal.ServantErr (err401, ServantErr(errHeaders)) import Servant.Server.Internal.ServantErr (err401, ServantErr(errHeaders))
import Servant.API.Authentication (AuthPolicy (Strict, Lax), import Servant.API.Authentication (AuthPolicy (Strict, Lax),
AuthProtected(..), AuthProtected(..),
AuthProtectedSimple (..),
BasicAuth (BasicAuth), BasicAuth (BasicAuth),
JWTAuth(..), JWTAuth(..),
OnMissing (..), OnMissing (..),
OnUnauthenticated (..)) OnUnauthenticated (..),
SAuthPolicy(..))
import Servant.Server.Internal.RoutingApplication
import Web.JWT (decodeAndVerifySignature, JWT, VerifiedJWT, Secret) import Web.JWT (decodeAndVerifySignature, JWT, VerifiedJWT, Secret)
import qualified Web.JWT as JWT (decode) import qualified Web.JWT as JWT (decode)
@ -49,10 +58,14 @@ class AuthData a e | a -> e where
authProtect :: OnMissing IO ServantErr missingPolicy missingError authProtect :: OnMissing IO ServantErr missingPolicy missingError
-> OnUnauthenticated IO ServantErr unauthPolicy unauthError authData -> OnUnauthenticated IO ServantErr unauthPolicy unauthError authData
-> (authData -> IO (Either unauthError usr)) -> (authData -> IO (Either unauthError usr))
-> subserver -> AuthProtected IO ServantErr missingPolicy missingError unauthPolicy unauthError authData usr
-> AuthProtected IO ServantErr missingPolicy missingError unauthPolicy unauthError authData usr subserver
authProtect = AuthProtected authProtect = AuthProtected
-- | combinator to create authentication protected servers.
authProtectSimple :: (Request -> IO (Either ServantErr u))
-> AuthProtectedSimple Request ServantErr u
authProtectSimple = AuthProtectedSimple
-- | 'BasicAuth' instance for authData -- | 'BasicAuth' instance for authData
instance AuthData (BasicAuth realm) () where instance AuthData (BasicAuth realm) () where
authData request = maybe (Left ()) Right $ do authData request = maybe (Left ()) Right $ do
@ -85,24 +98,22 @@ basicUnauthenticatedHandler :: forall realm. KnownSymbol realm
basicUnauthenticatedHandler p = StrictUnauthenticated (const . const (return $ basicAuthFailure p)) basicUnauthenticatedHandler p = StrictUnauthenticated (const . const (return $ basicAuthFailure p))
-- | Basic authentication combinator with strict failure. -- | Basic authentication combinator with strict failure.
basicAuthStrict :: forall realm usr subserver. KnownSymbol realm basicAuthStrict :: forall realm usr. KnownSymbol realm
=> (BasicAuth realm -> IO (Maybe usr)) => (BasicAuth realm -> IO (Maybe usr))
-> subserver -> AuthProtected IO ServantErr 'Strict () 'Strict () (BasicAuth realm) usr
-> AuthProtected IO ServantErr 'Strict () 'Strict () (BasicAuth realm) usr subserver basicAuthStrict check =
basicAuthStrict check sub =
let mHandler = basicMissingHandler (Proxy :: Proxy realm) let mHandler = basicMissingHandler (Proxy :: Proxy realm)
unauthHandler = basicUnauthenticatedHandler (Proxy :: Proxy realm) unauthHandler = basicUnauthenticatedHandler (Proxy :: Proxy realm)
check' = \auth -> maybe (Left ()) Right <$> check auth check' = \auth -> maybe (Left ()) Right <$> check auth
in AuthProtected mHandler unauthHandler check' sub in AuthProtected mHandler unauthHandler check'
-- | Basic authentication combinator with lax failure. -- | Basic authentication combinator with lax failure.
basicAuthLax :: KnownSymbol realm basicAuthLax :: KnownSymbol realm
=> (BasicAuth realm -> IO (Maybe usr)) => (BasicAuth realm -> IO (Maybe usr))
-> subserver -> AuthProtected IO ServantErr 'Lax () 'Lax () (BasicAuth realm) usr
-> AuthProtected IO ServantErr 'Lax () 'Lax () (BasicAuth realm) usr subserver basicAuthLax check =
basicAuthLax check sub =
let check' = \a -> maybe (Left ()) Right <$> check a let check' = \a -> maybe (Left ()) Right <$> check a
in AuthProtected LaxMissing LaxUnauthenticated check' sub in AuthProtected LaxMissing LaxUnauthenticated check'
-- | Authentication data we extract from requests for JWT-based authentication. -- | Authentication data we extract from requests for JWT-based authentication.
instance AuthData JWTAuth () where instance AuthData JWTAuth () where
@ -118,14 +129,100 @@ jwtWithError e = err401 { errHeaders = [("WWW-Authenticate", "Bearer error=\""<>
-- | OnMissing handler for Strict, JWT-based authentication -- | OnMissing handler for Strict, JWT-based authentication
jwtAuthStrict :: Secret jwtAuthStrict :: Secret
-> subserver -> AuthProtected IO ServantErr 'Strict () 'Strict () JWTAuth (JWT VerifiedJWT)
-> AuthProtected IO ServantErr 'Strict () 'Strict () JWTAuth (JWT VerifiedJWT) subserver jwtAuthStrict secret =
jwtAuthStrict secret sub =
let missingHandler = StrictMissing (const $ return (jwtWithError "invalid_request")) let missingHandler = StrictMissing (const $ return (jwtWithError "invalid_request"))
unauthHandler = StrictUnauthenticated (const . const (return $ jwtWithError "invalid_token")) unauthHandler = StrictUnauthenticated (const . const (return $ jwtWithError "invalid_token"))
check = return . maybe (Left ()) Right . decodeAndVerifySignature secret . unJWTAuth check = return . maybe (Left ()) Right . decodeAndVerifySignature secret . unJWTAuth
in AuthProtected missingHandler unauthHandler check sub in AuthProtected missingHandler unauthHandler check
-- | A type alias to make simple authentication endpoints -- | A type alias to make simple authentication endpoints
type SimpleAuthProtected mPolicy uPolicy authData usr subserver = type SimpleAuthProtected mPolicy uPolicy authData usr =
AuthProtected IO ServantErr mPolicy () uPolicy () authData usr subserver AuthProtected IO ServantErr mPolicy () uPolicy () authData usr
-------------------------------------------------------------------------------
-- Helpers
-------------------------------------------------------------------------------
-- | helper type family to capture server handled values for various policies
type family AuthDelayedReturn (mP :: AuthPolicy) mE (uP :: AuthPolicy) uE usr :: * where
AuthDelayedReturn 'Strict mE 'Strict uE usr = usr
AuthDelayedReturn 'Strict mE 'Lax uE usr = Either uE usr
AuthDelayedReturn 'Lax mE 'Strict uE usr = Either mE usr
AuthDelayedReturn 'Lax mE 'Lax uE usr = Either (Either mE uE) usr
-- | Internal method to generate auth checkers for various policies. Scary type signature
-- but it does help with understanding the logic of how each policy works. See
-- examples below.
genAuthCheck :: (OnMissing IO ServantErr mP mE -> mE -> IO (RouteResult (AuthDelayedReturn mP mE uP uE usr)))
-> (OnUnauthenticated IO ServantErr uP uE auth -> uE -> auth -> IO (RouteResult (AuthDelayedReturn mP mE uP uE usr)))
-> (usr -> (AuthDelayedReturn mP mE uP uE usr))
-> AuthProtected IO ServantErr mP mE uP uE auth usr
-> Delayed (AuthDelayedReturn mP mE uP uE usr -> a)
-> IO (RouteResult (Either mE auth))
-> Delayed a
genAuthCheck missingHandler unauthHandler returnHandler authProtection d new =
let newAuth =
new `bindRouteResults` \ eAuthData ->
case eAuthData of
-- we failed to extract authentication data from the request
Left mError -> missingHandler (onMissing authProtection) mError
-- auth data was succesfully extracted from the request
Right aData -> do
eUsr <- checkAuth authProtection aData
case eUsr of
-- we failed to authenticate the user
Left uError -> unauthHandler (onUnauthenticated authProtection) uError aData
-- user was authenticated
Right usr ->
(return . Route . returnHandler) usr
in addAuth d newAuth
-- | Delayed auth checker for Strict Missing and Strict Unauthentication
addAuthCheckSS :: AuthProtected IO ServantErr 'Strict mError 'Strict uError auth usr
-> Delayed (usr -> a)
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheckSS = genAuthCheck (\(StrictMissing handler) e -> FailFatal <$> handler e)
(\(StrictUnauthenticated handler) e a -> FailFatal <$> handler e a)
id
-- | Delayed auth checker for Strict Missing and Lax Unauthentication
addAuthCheckSL :: AuthProtected IO ServantErr 'Strict mError 'Lax uError auth usr
-> Delayed (Either uError usr -> a)
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheckSL = genAuthCheck (\(StrictMissing handler) e -> FailFatal <$> handler e)
(\(LaxUnauthenticated) e _ -> (return . Route . Left) e)
Right
-- | Delayed auth checker for Lax Missing and Strict Unauthentication
addAuthCheckLS :: AuthProtected IO ServantErr 'Lax mError 'Strict uError auth usr
-> Delayed (Either mError usr -> a)
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheckLS = genAuthCheck (\(LaxMissing) e -> (return . Route . Left) e)
(\(StrictUnauthenticated handler) e a -> FailFatal <$> handler e a)
Right
-- | Delayed auth checker for Lax Missing and Lax Unauthentication
addAuthCheckLL :: AuthProtected IO ServantErr 'Lax mError 'Lax uError auth usr
-> Delayed (Either (Either mError uError) usr -> a)
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheckLL = genAuthCheck (\(LaxMissing) e -> (return . Route . Left . Left) e)
(\(LaxUnauthenticated) e _ -> (return . Route . Left . Right) e)
Right
-- | Add an auth check by supplying OnMissing policies and OnUnauthenticated policies.
addAuthCheck :: SAuthPolicy mPolicy
-> SAuthPolicy uPolicy
-> AuthProtected IO ServantErr mPolicy mError uPolicy uError auth usr
-> Delayed (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a)
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheck SStrict SStrict = addAuthCheckSS
addAuthCheck SStrict SLax = addAuthCheckSL
addAuthCheck SLax SStrict = addAuthCheckLS
addAuthCheck SLax SLax = addAuthCheckLL

View file

@ -0,0 +1,76 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
#if !MIN_VERSION_base(4,8,0)
{-# LANGUAGE OverlappingInstances #-}
#endif
module Servant.Server.Internal.Config where
import Control.DeepSeq (NFData(rnf))
import GHC.Generics (Generic)
import Data.Typeable (Typeable)
-- | A single entry in the configuration. The first parameter is phantom, and
-- is used to lookup a @ConfigEntry@ in a @Config@.
newtype ConfigEntry tag a = ConfigEntry { unConfigEntry :: a }
deriving ( Eq, Show, Read, Enum, Integral, Fractional, Generic, Typeable
, Num, Ord, Real, Functor, Foldable, Traversable, NFData)
instance Applicative (ConfigEntry tag) where
pure = ConfigEntry
ConfigEntry f <*> ConfigEntry a = ConfigEntry $ f a
instance Monad (ConfigEntry tag) where
return = ConfigEntry
ConfigEntry a >>= f = f a
-- | The entire configuration.
data Config a where
EmptyConfig :: Config '[]
ConsConfig :: x -> Config xs -> Config (x ': xs)
instance Eq (Config '[]) where
_ == _ = True
instance (Eq a, Eq (Config as)) => Eq (Config (a ' : as)) where
ConsConfig x1 y1 == ConsConfig x2 y2 = x1 == x2 && y1 == y2
instance NFData (Config '[]) where
rnf EmptyConfig = ()
instance (NFData a, NFData (Config as)) => NFData (Config (a ': as)) where
rnf (x `ConsConfig` ys) = rnf x `seq` rnf ys
(.:) :: x -> Config xs -> Config (ConfigEntry tag x ': xs)
e .: cfg = ConsConfig (ConfigEntry e) cfg
infixr 4 .:
class HasConfigEntry (cfg :: [*]) (a :: k) (val :: *) | cfg a -> val where
getConfigEntry :: proxy a -> Config cfg -> val
instance
#if MIN_VERSION_base(4,8,0)
{-# OVERLAPPABLE #-}
#endif
HasConfigEntry xs tag val => HasConfigEntry (notIt ': xs) tag val where
getConfigEntry p (ConsConfig _ xs) = getConfigEntry p xs
instance
#if MIN_VERSION_base(4,8,0)
{-# OVERLAPPABLE #-}
#endif
HasConfigEntry (ConfigEntry tag val ': xs) tag val where
getConfigEntry _ (ConsConfig x _) = unConfigEntry x

View file

@ -28,9 +28,6 @@ import qualified Control.Monad.Writer.Strict as SWriter
import Data.Typeable import Data.Typeable
import Servant.API import Servant.API
import Servant.API.Authentication
-- import Servant.Server.Internal.Authentication (AuthProtected (AuthProtectedStrict, AuthProtectedLax))
class Enter typ arg ret | typ arg -> ret, typ ret -> arg where class Enter typ arg ret | typ arg -> ret, typ ret -> arg where
enter :: arg -> typ -> ret enter :: arg -> typ -> ret
@ -99,10 +96,3 @@ squashNat = Nat squash
-- | Like @mmorph@'s `generalize`. -- | Like @mmorph@'s `generalize`.
generalizeNat :: Applicative m => Identity :~> m generalizeNat :: Applicative m => Identity :~> m
generalizeNat = Nat (pure . runIdentity) generalizeNat = Nat (pure . runIdentity)
-- | 'Enter' instance for AuthProtected
instance Enter subserver arg ret => Enter (AuthProtected m e mP mE uP uE authData usr subserver)
arg
(AuthProtected m e mP mE uP uE authData usr ret)
where
enter arg (AuthProtected mHandler uHandler check sub) = AuthProtected mHandler uHandler check (enter arg sub)

View file

@ -21,8 +21,6 @@ import Network.Wai (Application, Request,
Response, ResponseReceived, Response, ResponseReceived,
requestBody, requestBody,
strictRequestBody) strictRequestBody)
import Servant.API.Authentication (AuthProtected (..), AuthPolicy(Strict,Lax),
OnMissing (..), OnUnauthenticated (..), SAuthPolicy (..))
import Servant.Server.Internal.ServantErr import Servant.Server.Internal.ServantErr
type RoutingApplication = type RoutingApplication =
@ -180,83 +178,12 @@ addMethodCheck :: Delayed a
addMethodCheck (Delayed captures method auth body server) new = addMethodCheck (Delayed captures method auth body server) new =
Delayed captures (combineRouteResults const method new) auth body server Delayed captures (combineRouteResults const method new) auth body server
-- | helper type family to capture server handled values for various policies -- | Add authentication
type family AuthDelayedReturn (mP :: AuthPolicy) mE (uP :: AuthPolicy) uE usr :: * where addAuth :: Delayed (a -> b)
AuthDelayedReturn 'Strict mE 'Strict uE usr = usr -> IO (RouteResult a)
AuthDelayedReturn 'Strict mE 'Lax uE usr = Either uE usr -> Delayed b
AuthDelayedReturn 'Lax mE 'Strict uE usr = Either mE usr addAuth (Delayed captures method auth body server) new =
AuthDelayedReturn 'Lax mE 'Lax uE usr = Either (Either mE uE) usr Delayed captures method (combineRouteResults (,) auth new) body (\ x (y,v) z -> ($ v) <$> server x y z)
-- | Internal method to generate auth checkers for various policies. Scary type signature
-- but it does help with understanding the logic of how each policy works. See
-- examples below.
genAuthCheck :: (OnMissing IO ServantErr mPolicy mError -> (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a) -> mError -> IO (RouteResult a))
-> (OnUnauthenticated IO ServantErr uPolicy uError auth -> (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a) -> uError -> auth -> IO (RouteResult a))
-> (usr -> (AuthDelayedReturn mPolicy mError uPolicy uError usr))
-> Delayed (AuthProtected IO ServantErr mPolicy mError uPolicy uError auth usr (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a
genAuthCheck missingHandler unauthHandler returnHandler d@(Delayed captures method _ body _) new =
let newAuth =
runDelayed d `bindRouteResults` \ authProtection ->
new `bindRouteResults` \ eAuthData ->
case eAuthData of
-- we failed to extract authentication data from the request
Left mError -> missingHandler (onMissing authProtection) (subserver authProtection) mError
-- auth data was succesfully extracted from the request
Right aData -> do
eUsr <- checkAuth authProtection aData
case eUsr of
-- we failed to authenticate the user
Left uError -> unauthHandler (onUnauthenticated authProtection) (subserver authProtection) uError aData
-- user was authenticated
Right usr ->
(return . Route . subserver authProtection) (returnHandler usr)
in Delayed captures method newAuth body (\_ y _ -> Route y)
-- | Delayed auth checker for Strict Missing and Strict Unauthentication
addAuthCheckSS :: Delayed (AuthProtected IO ServantErr 'Strict mError 'Strict uError auth usr (usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheckSS = genAuthCheck (\(StrictMissing handler) _ e -> FailFatal <$> handler e)
(\(StrictUnauthenticated handler) _ e a -> FailFatal <$> handler e a)
id
-- | Delayed auth checker for Strict Missing and Lax Unauthentication
addAuthCheckSL :: Delayed (AuthProtected IO ServantErr 'Strict mError 'Lax uError auth usr (Either uError usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheckSL = genAuthCheck (\(StrictMissing handler) _ e -> FailFatal <$> handler e)
(\(LaxUnauthenticated) cont e _ -> (return . Route . cont) (Left e))
Right
-- | Delayed auth checker for Lax Missing and Strict Unauthentication
addAuthCheckLS :: Delayed (AuthProtected IO ServantErr 'Lax mError 'Strict uError auth usr (Either mError usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheckLS = genAuthCheck (\(LaxMissing) cont e -> (return . Route . cont) (Left e))
(\(StrictUnauthenticated handler) _ e a -> FailFatal <$> handler e a)
Right
-- | Delayed auth checker for Lax Missing and Lax Unauthentication
addAuthCheckLL :: Delayed (AuthProtected IO ServantErr 'Lax mError 'Lax uError auth usr (Either (Either mError uError) usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheckLL = genAuthCheck (\(LaxMissing) cont e -> (return . Route . cont) (Left (Left e)))
(\(LaxUnauthenticated) cont e _ -> (return . Route . cont) (Left (Right e)))
Right
-- | Add an auth check by supplying OnMissing policies and OnUnauthenticated policies.
addAuthCheck :: SAuthPolicy mPolicy
-> SAuthPolicy uPolicy
-> Delayed (AuthProtected IO ServantErr mPolicy mError uPolicy uError auth usr (AuthDelayedReturn mPolicy mError uPolicy uError usr -> a))
-> IO (RouteResult (Either mError auth))
-> Delayed a
addAuthCheck SStrict SStrict = addAuthCheckSS
addAuthCheck SStrict SLax = addAuthCheckSL
addAuthCheck SLax SStrict = addAuthCheckLS
addAuthCheck SLax SLax = addAuthCheckLL
-- | Add a body check to the end of the body block. -- | Add a body check to the end of the body block.
addBodyCheck :: Delayed (a -> b) addBodyCheck :: Delayed (a -> b)

View file

@ -42,7 +42,7 @@ errorOrderServer = \_ _ -> throwE err402
errorOrderSpec :: Spec errorOrderSpec :: Spec
errorOrderSpec = describe "HTTP error order" errorOrderSpec = describe "HTTP error order"
$ with (return $ serve errorOrderApi errorOrderServer) $ do $ with (return $ serve errorOrderApi EmptyConfig errorOrderServer) $ do
let badContentType = (hContentType, "text/plain") let badContentType = (hContentType, "text/plain")
badAccept = (hAccept, "text/plain") badAccept = (hAccept, "text/plain")
badMethod = methodGet badMethod = methodGet
@ -89,7 +89,7 @@ prioErrorsApi = Proxy
prioErrorsSpec :: Spec prioErrorsSpec :: Spec
prioErrorsSpec = describe "PrioErrors" $ do prioErrorsSpec = describe "PrioErrors" $ do
let server = return let server = return
with (return $ serve prioErrorsApi server) $ do with (return $ serve prioErrorsApi EmptyConfig server) $ do
let check (mdescr, method) path (cdescr, ctype, body) resp = let check (mdescr, method) path (cdescr, ctype, body) resp =
it fulldescr $ it fulldescr $
Test.Hspec.Wai.request method path [(hContentType, ctype)] body Test.Hspec.Wai.request method path [(hContentType, ctype)] body
@ -154,7 +154,7 @@ errorRetryServer
errorRetrySpec :: Spec errorRetrySpec :: Spec
errorRetrySpec = describe "Handler search" errorRetrySpec = describe "Handler search"
$ with (return $ serve errorRetryApi errorRetryServer) $ do $ with (return $ serve errorRetryApi EmptyConfig errorRetryServer) $ do
let jsonCT = (hContentType, "application/json") let jsonCT = (hContentType, "application/json")
jsonAccept = (hAccept, "application/json") jsonAccept = (hAccept, "application/json")
@ -194,7 +194,7 @@ errorChoiceServer = return 0
errorChoiceSpec :: Spec errorChoiceSpec :: Spec
errorChoiceSpec = describe "Multiple handlers return errors" errorChoiceSpec = describe "Multiple handlers return errors"
$ with (return $ serve errorChoiceApi errorChoiceServer) $ do $ with (return $ serve errorChoiceApi EmptyConfig errorChoiceServer) $ do
it "should respond with 404 if no path matches" $ do it "should respond with 404 if no path matches" $ do
request methodGet "" [] "" `shouldRespondWith` 404 request methodGet "" [] "" `shouldRespondWith` 404

View file

@ -0,0 +1,37 @@
{-# LANGUAGE DataKinds #-}
{-# OPTIONS_GHC -fdefer-type-errors #-}
module Servant.Server.Internal.ConfigSpec (spec) where
import Data.Proxy (Proxy (..))
import Test.Hspec (Spec, describe, it, shouldBe)
import Test.ShouldNotTypecheck (shouldNotTypecheck)
import Servant.Server.Internal.Config
spec :: Spec
spec = do
getConfigEntrySpec
getConfigEntrySpec :: Spec
getConfigEntrySpec = describe "getConfigEntry" $ do
let cfg1 = 0 .: EmptyConfig :: Config '[ConfigEntry "a" Int]
cfg2 = 1 .: cfg1 :: Config '[ConfigEntry "a" Int, ConfigEntry "a" Int]
it "gets the config if a matching one exists" $ do
getConfigEntry (Proxy :: Proxy "a") cfg1 `shouldBe` 0
it "gets the first matching config" $ do
getConfigEntry (Proxy :: Proxy "a") cfg2 `shouldBe` 1
it "does not typecheck if key does not exist" $ do
let x = getConfigEntry (Proxy :: Proxy "b") cfg1 :: Int
shouldNotTypecheck x
it "does not typecheck if key maps to a different type" $ do
let x = getConfigEntry (Proxy :: Proxy "a") cfg1 :: String
shouldNotTypecheck x

View file

@ -48,12 +48,12 @@ combinedReaderServer = enter fReader combinedReaderServer'
enterSpec :: Spec enterSpec :: Spec
enterSpec = describe "Enter" $ do enterSpec = describe "Enter" $ do
with (return (serve readerAPI readerServer)) $ do with (return (serve readerAPI EmptyConfig readerServer)) $ do
it "allows running arbitrary monads" $ do it "allows running arbitrary monads" $ do
get "int" `shouldRespondWith` "1797" get "int" `shouldRespondWith` "1797"
post "string" "3" `shouldRespondWith` "\"hi\""{ matchStatus = 201 } post "string" "3" `shouldRespondWith` "\"hi\""{ matchStatus = 201 }
with (return (serve combinedAPI combinedReaderServer)) $ do with (return (serve combinedAPI EmptyConfig combinedReaderServer)) $ do
it "allows combnation of enters" $ do it "allows combnation of enters" $ do
get "bool" `shouldRespondWith` "true" get "bool" `shouldRespondWith` "true"

View file

@ -3,6 +3,7 @@
{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-} {-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeSynonymInstances #-} {-# LANGUAGE TypeSynonymInstances #-}
@ -36,7 +37,7 @@ import Network.HTTP.Types (hAccept, hContentType,
ok200, parseQuery, ResponseHeaders, Status(..)) ok200, parseQuery, ResponseHeaders, Status(..))
import Network.Wai (Application, Request, pathInfo, import Network.Wai (Application, Request, pathInfo,
queryString, rawQueryString, queryString, rawQueryString,
responseLBS, responseBuilder) responseBuilder, responseLBS)
import Network.Wai.Internal (Response (ResponseBuilder)) import Network.Wai.Internal (Response (ResponseBuilder))
import Network.Wai.Test (defaultRequest, request, import Network.Wai.Test (defaultRequest, request,
runSession, simpleBody, simpleHeaders, SResponse) runSession, simpleBody, simpleHeaders, SResponse)
@ -47,7 +48,7 @@ import Servant.API ((:<|>) (..), (:>), Capture, Delete,
QueryFlag, QueryParam, QueryParams, QueryFlag, QueryParam, QueryParams,
Raw, RemoteHost, ReqBody, Raw, RemoteHost, ReqBody,
addHeader) addHeader)
import Servant.Server (Server, serve, ServantErr(..), err404) import Servant.Server ((.:), ConfigEntry, Config(EmptyConfig), Server, serve, ServantErr(..), err404)
import Test.Hspec (Spec, describe, it, shouldBe, shouldContain) import Test.Hspec (Spec, describe, it, shouldBe, shouldContain)
import Test.Hspec.Wai (get, liftIO, matchHeaders, import Test.Hspec.Wai (get, liftIO, matchHeaders,
matchStatus, post, request, matchStatus, post, request,
@ -62,7 +63,6 @@ import Servant.Server.Internal.Authentication
import Servant.Server.Internal.RoutingApplication (RouteResult(Route)) import Servant.Server.Internal.RoutingApplication (RouteResult(Route))
import Web.JWT hiding (JSON) import Web.JWT hiding (JSON)
-- * test data types -- * test data types
data Person = Person { data Person = Person {
@ -126,7 +126,7 @@ captureServer legs = case legs of
captureSpec :: Spec captureSpec :: Spec
captureSpec = do captureSpec = do
describe "Servant.API.Capture" $ do describe "Servant.API.Capture" $ do
with (return (serve captureApi captureServer)) $ do with (return (serve captureApi EmptyConfig captureServer)) $ do
it "can capture parts of the 'pathInfo'" $ do it "can capture parts of the 'pathInfo'" $ do
response <- get "/2" response <- get "/2"
@ -137,6 +137,7 @@ captureSpec = do
with (return (serve with (return (serve
(Proxy :: Proxy (Capture "captured" String :> Raw)) (Proxy :: Proxy (Capture "captured" String :> Raw))
EmptyConfig
(\ "captured" request_ respond -> (\ "captured" request_ respond ->
respond $ responseLBS ok200 [] (cs $ show $ pathInfo request_)))) $ do respond $ responseLBS ok200 [] (cs $ show $ pathInfo request_)))) $ do
it "strips the captured path snippet from pathInfo" $ do it "strips the captured path snippet from pathInfo" $ do
@ -153,7 +154,7 @@ getSpec :: Spec
getSpec = do getSpec = do
describe "Servant.API.Get" $ do describe "Servant.API.Get" $ do
let server = return alice :<|> return () :<|> return () let server = return alice :<|> return () :<|> return ()
with (return $ serve getApi server) $ do with (return $ serve getApi EmptyConfig server) $ do
it "allows to GET a Person" $ do it "allows to GET a Person" $ do
response <- get "/" response <- get "/"
@ -176,7 +177,7 @@ headSpec :: Spec
headSpec = do headSpec = do
describe "Servant.API.Head" $ do describe "Servant.API.Head" $ do
let server = return alice :<|> return () :<|> return () let server = return alice :<|> return () :<|> return ()
with (return $ serve getApi server) $ do with (return $ serve getApi EmptyConfig server) $ do
it "allows to GET a Person" $ do it "allows to GET a Person" $ do
response <- Test.Hspec.Wai.request methodHead "/" [] "" response <- Test.Hspec.Wai.request methodHead "/" [] ""
@ -223,7 +224,7 @@ queryParamSpec :: Spec
queryParamSpec = do queryParamSpec = do
describe "Servant.API.QueryParam" $ do describe "Servant.API.QueryParam" $ do
it "allows to retrieve simple GET parameters" $ it "allows to retrieve simple GET parameters" $
(flip runSession) (serve queryParamApi qpServer) $ do (flip runSession) (serve queryParamApi EmptyConfig qpServer) $ do
let params1 = "?name=bob" let params1 = "?name=bob"
response1 <- Network.Wai.Test.request defaultRequest{ response1 <- Network.Wai.Test.request defaultRequest{
rawQueryString = params1, rawQueryString = params1,
@ -235,7 +236,7 @@ queryParamSpec = do
} }
it "allows to retrieve lists in GET parameters" $ it "allows to retrieve lists in GET parameters" $
(flip runSession) (serve queryParamApi qpServer) $ do (flip runSession) (serve queryParamApi EmptyConfig qpServer) $ do
let params2 = "?names[]=bob&names[]=john" let params2 = "?names[]=bob&names[]=john"
response2 <- Network.Wai.Test.request defaultRequest{ response2 <- Network.Wai.Test.request defaultRequest{
rawQueryString = params2, rawQueryString = params2,
@ -249,7 +250,7 @@ queryParamSpec = do
it "allows to retrieve value-less GET parameters" $ it "allows to retrieve value-less GET parameters" $
(flip runSession) (serve queryParamApi qpServer) $ do (flip runSession) (serve queryParamApi EmptyConfig qpServer) $ do
let params3 = "?capitalize" let params3 = "?capitalize"
response3 <- Network.Wai.Test.request defaultRequest{ response3 <- Network.Wai.Test.request defaultRequest{
rawQueryString = params3, rawQueryString = params3,
@ -295,7 +296,7 @@ postSpec :: Spec
postSpec = do postSpec = do
describe "Servant.API.Post and .ReqBody" $ do describe "Servant.API.Post and .ReqBody" $ do
let server = return . age :<|> return . age :<|> return () let server = return . age :<|> return . age :<|> return ()
with (return $ serve postApi server) $ do with (return $ serve postApi EmptyConfig server) $ do
let post' x = Test.Hspec.Wai.request methodPost x [(hContentType let post' x = Test.Hspec.Wai.request methodPost x [(hContentType
, "application/json;charset=utf-8")] , "application/json;charset=utf-8")]
@ -337,7 +338,7 @@ putSpec :: Spec
putSpec = do putSpec = do
describe "Servant.API.Put and .ReqBody" $ do describe "Servant.API.Put and .ReqBody" $ do
let server = return . age :<|> return . age :<|> return () let server = return . age :<|> return . age :<|> return ()
with (return $ serve putApi server) $ do with (return $ serve putApi EmptyConfig server) $ do
let put' x = Test.Hspec.Wai.request methodPut x [(hContentType let put' x = Test.Hspec.Wai.request methodPut x [(hContentType
, "application/json;charset=utf-8")] , "application/json;charset=utf-8")]
@ -379,7 +380,7 @@ patchSpec :: Spec
patchSpec = do patchSpec = do
describe "Servant.API.Patch and .ReqBody" $ do describe "Servant.API.Patch and .ReqBody" $ do
let server = return . age :<|> return . age :<|> return () let server = return . age :<|> return . age :<|> return ()
with (return $ serve patchApi server) $ do with (return $ serve patchApi EmptyConfig server) $ do
let patch' x = Test.Hspec.Wai.request methodPatch x [(hContentType let patch' x = Test.Hspec.Wai.request methodPatch x [(hContentType
, "application/json;charset=utf-8")] , "application/json;charset=utf-8")]
@ -424,13 +425,13 @@ headerSpec = describe "Servant.API.Header" $ do
expectsString (Just x) = when (x /= "more from you") $ error "Expected more from you" expectsString (Just x) = when (x /= "more from you") $ error "Expected more from you"
expectsString Nothing = error "Expected a string" expectsString Nothing = error "Expected a string"
with (return (serve headerApi expectsInt)) $ do with (return (serve headerApi EmptyConfig expectsInt)) $ do
let delete' x = Test.Hspec.Wai.request methodDelete x [("MyHeader" ,"5")] let delete' x = Test.Hspec.Wai.request methodDelete x [("MyHeader" ,"5")]
it "passes the header to the handler (Int)" $ it "passes the header to the handler (Int)" $
delete' "/" "" `shouldRespondWith` 204 delete' "/" "" `shouldRespondWith` 204
with (return (serve headerApi expectsString)) $ do with (return (serve headerApi EmptyConfig expectsString)) $ do
let delete' x = Test.Hspec.Wai.request methodDelete x [("MyHeader" ,"more from you")] let delete' x = Test.Hspec.Wai.request methodDelete x [("MyHeader" ,"more from you")]
it "passes the header to the handler (String)" $ it "passes the header to the handler (String)" $
@ -447,7 +448,7 @@ rawSpec :: Spec
rawSpec = do rawSpec = do
describe "Servant.API.Raw" $ do describe "Servant.API.Raw" $ do
it "runs applications" $ do it "runs applications" $ do
(flip runSession) (serve rawApi (rawApplication (const (42 :: Integer)))) $ do (flip runSession) (serve rawApi EmptyConfig (rawApplication (const (42 :: Integer)))) $ do
response <- Network.Wai.Test.request defaultRequest{ response <- Network.Wai.Test.request defaultRequest{
pathInfo = ["foo"] pathInfo = ["foo"]
} }
@ -455,7 +456,7 @@ rawSpec = do
simpleBody response `shouldBe` "42" simpleBody response `shouldBe` "42"
it "gets the pathInfo modified" $ do it "gets the pathInfo modified" $ do
(flip runSession) (serve rawApi (rawApplication pathInfo)) $ do (flip runSession) (serve rawApi EmptyConfig (rawApplication pathInfo)) $ do
response <- Network.Wai.Test.request defaultRequest{ response <- Network.Wai.Test.request defaultRequest{
pathInfo = ["foo", "bar"] pathInfo = ["foo", "bar"]
} }
@ -485,7 +486,7 @@ unionServer =
unionSpec :: Spec unionSpec :: Spec
unionSpec = do unionSpec = do
describe "Servant.API.Alternative" $ do describe "Servant.API.Alternative" $ do
with (return $ serve unionApi unionServer) $ do with (return $ serve unionApi EmptyConfig unionServer) $ do
it "unions endpoints" $ do it "unions endpoints" $ do
response <- get "/foo" response <- get "/foo"
@ -517,7 +518,7 @@ responseHeadersServer = let h = return $ addHeader 5 $ addHeader "kilroy" "hi"
responseHeadersSpec :: Spec responseHeadersSpec :: Spec
responseHeadersSpec = describe "ResponseHeaders" $ do responseHeadersSpec = describe "ResponseHeaders" $ do
with (return $ serve (Proxy :: Proxy ResponseHeadersApi) responseHeadersServer) $ do with (return $ serve (Proxy :: Proxy ResponseHeadersApi) EmptyConfig responseHeadersServer) $ do
let methods = [(methodGet, 200), (methodPost, 201), (methodPut, 200), (methodPatch, 200)] let methods = [(methodGet, 200), (methodPost, 201), (methodPut, 200), (methodPatch, 200)]
@ -571,7 +572,7 @@ prioErrorsApi = Proxy
prioErrorsSpec :: Spec prioErrorsSpec :: Spec
prioErrorsSpec = describe "PrioErrors" $ do prioErrorsSpec = describe "PrioErrors" $ do
let server = return . age let server = return . age
with (return $ serve prioErrorsApi server) $ do with (return $ serve prioErrorsApi EmptyConfig server) $ do
let check (mdescr, method) path (cdescr, ctype, body) resp = let check (mdescr, method) path (cdescr, ctype, body) resp =
it fulldescr $ it fulldescr $
Test.Hspec.Wai.request method path [(hContentType, ctype)] body Test.Hspec.Wai.request method path [(hContentType, ctype)] body
@ -625,7 +626,7 @@ miscServ = versionHandler
hostHandler = return . show hostHandler = return . show
miscReqCombinatorsSpec :: Spec miscReqCombinatorsSpec :: Spec
miscReqCombinatorsSpec = with (return $ serve miscApi miscServ) $ miscReqCombinatorsSpec = with (return $ serve miscApi EmptyConfig miscServ) $
describe "Misc. combinators for request inspection" $ do describe "Misc. combinators for request inspection" $ do
it "Successfully gets the HTTP version specified in the request" $ it "Successfully gets the HTTP version specified in the request" $
go "/version" "\"HTTP/1.0\"" go "/version" "\"HTTP/1.0\""
@ -642,8 +643,8 @@ miscReqCombinatorsSpec = with (return $ serve miscApi miscServ) $
-- | we include two endpoints /foo and /bar and we put the BasicAuth -- | we include two endpoints /foo and /bar and we put the BasicAuth
-- portion in two different places -- portion in two different places
type AuthUser = ByteString type AuthUser = ByteString
type BasicAuthFooRealm = AuthProtect (BasicAuth "foo-realm") AuthUser 'Strict () 'Strict () type BasicAuthFooRealm = AuthProtect "foo" (BasicAuth "foo-realm") AuthUser 'Strict () 'Strict ()
type BasicAuthBarRealm = AuthProtect (BasicAuth "bar-realm") AuthUser 'Strict () 'Strict () type BasicAuthBarRealm = AuthProtect "bar" (BasicAuth "bar-realm") AuthUser 'Strict () 'Strict ()
type BasicAuthRequiredAPI = BasicAuthFooRealm :> "foo" :> Get '[JSON] Person type BasicAuthRequiredAPI = BasicAuthFooRealm :> "foo" :> Get '[JSON] Person
:<|> "bar" :> BasicAuthBarRealm :> Get '[JSON] Animal :<|> "bar" :> BasicAuthBarRealm :> Get '[JSON] Animal
@ -656,12 +657,13 @@ basicAuthBarCheck :: BasicAuth "bar-realm" -> IO (Maybe AuthUser)
basicAuthBarCheck (BasicAuth usr pass) = if usr == "bar" && pass == "bar" basicAuthBarCheck (BasicAuth usr pass) = if usr == "bar" && pass == "bar"
then return (Just "bar") then return (Just "bar")
else return Nothing else return Nothing
basicBasicAuthRequiredApi :: Proxy BasicAuthRequiredAPI basicBasicAuthRequiredApi :: Proxy BasicAuthRequiredAPI
basicBasicAuthRequiredApi = Proxy basicBasicAuthRequiredApi = Proxy
basicAuthRequiredServer :: Server BasicAuthRequiredAPI basicAuthRequiredServer :: Server BasicAuthRequiredAPI
basicAuthRequiredServer = basicAuthStrict basicAuthFooCheck (const . return $ alice) basicAuthRequiredServer = (const . return $ alice)
:<|> basicAuthStrict basicAuthBarCheck (const . return $ jerry) :<|> (const . return $ jerry)
-- base64-encoded "servant:server" -- base64-encoded "servant:server"
base64ServantColonServer :: ByteString base64ServantColonServer :: ByteString
@ -681,7 +683,11 @@ basicAuthGet path base64EncodedAuth = Test.Hspec.Wai.request methodGet path [("A
basicAuthRequiredSpec :: Spec basicAuthRequiredSpec :: Spec
basicAuthRequiredSpec = do basicAuthRequiredSpec = do
describe "Servant.API.Authentication" $ do describe "Servant.API.Authentication" $ do
with (return $ serve basicBasicAuthRequiredApi basicAuthRequiredServer) $ do let fooAuthProtect = basicAuthStrict basicAuthFooCheck
barAuthProtect = basicAuthStrict basicAuthBarCheck
config :: Config [ConfigEntry "foo" (AuthProtected IO ServantErr 'Strict () 'Strict () (BasicAuth "foo-realm") AuthUser), ConfigEntry "bar" (AuthProtected IO ServantErr 'Strict () 'Strict () (BasicAuth "bar-realm") AuthUser)]
config = fooAuthProtect .: barAuthProtect .: EmptyConfig
with (return $ serve basicBasicAuthRequiredApi config basicAuthRequiredServer) $ do
it "allows access with the correct username and password" $ do it "allows access with the correct username and password" $ do
response1 <- basicAuthGet "/foo" base64ServantColonServer response1 <- basicAuthGet "/foo" base64ServantColonServer
liftIO $ do liftIO $ do
@ -709,7 +715,7 @@ basicAuthRequiredSpec = do
(simpleHeaders bar401) `shouldContain` barHeader (simpleHeaders bar401) `shouldContain` barHeader
type JWTAuthProtect = AuthProtect JWTAuth (JWT VerifiedJWT) 'Strict () 'Strict () type JWTAuthProtect = AuthProtect "jwt" JWTAuth (JWT VerifiedJWT) 'Strict () 'Strict ()
type JWTAuthRequiredAPI = JWTAuthProtect :> "foo" :> Get '[JSON] Person type JWTAuthRequiredAPI = JWTAuthProtect :> "foo" :> Get '[JSON] Person
@ -718,8 +724,9 @@ jwtAuthRequiredApi :: Proxy JWTAuthRequiredAPI
jwtAuthRequiredApi = Proxy jwtAuthRequiredApi = Proxy
jwtSecret = secret "secret" jwtSecret = secret "secret"
jwtAuthRequiredServer :: Server JWTAuthRequiredAPI jwtAuthRequiredServer :: Server JWTAuthRequiredAPI
jwtAuthRequiredServer = jwtAuthStrict jwtSecret (const . return $ alice) jwtAuthRequiredServer = const . return $ alice
correctToken = encodeUtf8 $ encodeSigned HS256 jwtSecret def correctToken = encodeUtf8 $ encodeSigned HS256 jwtSecret def
corruptToken = "blah" corruptToken = "blah"
@ -731,7 +738,10 @@ jwtAuthGet path token = Test.Hspec.Wai.request methodGet path [("Authorization",
jwtAuthRequiredSpec :: Spec jwtAuthRequiredSpec :: Spec
jwtAuthRequiredSpec = do jwtAuthRequiredSpec = do
describe "JWT Auth" $ do describe "JWT Auth" $ do
with (return $ serve jwtAuthRequiredApi jwtAuthRequiredServer) $ do let jwtAuthProtect = jwtAuthStrict jwtSecret
config :: Config '[ConfigEntry "jwt" (AuthProtected IO ServantErr 'Strict () 'Strict () JWTAuth (JWT VerifiedJWT))]
config = jwtAuthProtect .: EmptyConfig
with (return $ serve jwtAuthRequiredApi config jwtAuthRequiredServer) $ do
it "allows access with the correct token" $ do it "allows access with the correct token" $ do
response <- jwtAuthGet "/foo" correctToken response <- jwtAuthGet "/foo" correctToken
liftIO $ do liftIO $ do

View file

@ -21,7 +21,7 @@ import Servant.API.Capture (Capture)
import Servant.API.Get (Get) import Servant.API.Get (Get)
import Servant.API.Raw (Raw) import Servant.API.Raw (Raw)
import Servant.API.Sub ((:>)) import Servant.API.Sub ((:>))
import Servant.Server (Server, serve) import Servant.Server (Server, serve, Config(EmptyConfig))
import Servant.ServerSpec (Person (Person)) import Servant.ServerSpec (Person (Person))
import Servant.Utils.StaticFiles (serveDirectory) import Servant.Utils.StaticFiles (serveDirectory)
@ -34,7 +34,7 @@ api :: Proxy Api
api = Proxy api = Proxy
app :: Application app :: Application
app = serve api server app = serve api EmptyConfig server
server :: Server Api server :: Server Api
server = server =

View file

@ -8,7 +8,9 @@
module Servant.API.Authentication module Servant.API.Authentication
( AuthPolicy (..) ( AuthPolicy (..)
, AuthProtect , AuthProtect
, AuthProtectSimple
, AuthProtected (..) , AuthProtected (..)
, AuthProtectedSimple (..)
, BasicAuth (..) , BasicAuth (..)
, JWTAuth (..) , JWTAuth (..)
, OnMissing (..) , OnMissing (..)
@ -35,7 +37,9 @@ data SAuthPolicy (p :: AuthPolicy) where
SLax :: SAuthPolicy 'Lax SLax :: SAuthPolicy 'Lax
-- | the combinator to be used in API types -- | the combinator to be used in API types
data AuthProtect authData usr (missingPolicy :: AuthPolicy) missingError (unauthPolicy :: AuthPolicy) unauthError data AuthProtect (tag :: k) authData usr (missingPolicy :: AuthPolicy) missingError (unauthPolicy :: AuthPolicy) unauthError
data AuthProtectSimple (tag :: k) (usr :: *)
-- | A GADT indexed by policy strictness that encompasses the ways -- | A GADT indexed by policy strictness that encompasses the ways
-- users will handle the case where authentication data is missing -- users will handle the case where authentication data is missing
@ -80,13 +84,14 @@ data OnUnauthenticated m responseError (policy :: AuthPolicy) errorIndex authDat
-- authData: the type of authData present in a request (e.g. JWT token) -- authData: the type of authData present in a request (e.g. JWT token)
-- usr: a data type extracted from the authenticated data. This data is likely fetched from a database. -- usr: a data type extracted from the authenticated data. This data is likely fetched from a database.
-- subserver: the rest of the servant API. -- subserver: the rest of the servant API.
data AuthProtected m rError (mPolicy :: AuthPolicy) mError (uPolicy :: AuthPolicy) uError authData usr subserver = data AuthProtected m rError (mPolicy :: AuthPolicy) mError (uPolicy :: AuthPolicy) uError authData usr =
AuthProtected { onMissing :: OnMissing m rError mPolicy mError AuthProtected { onMissing :: OnMissing m rError mPolicy mError
, onUnauthenticated :: OnUnauthenticated m rError uPolicy uError authData , onUnauthenticated :: OnUnauthenticated m rError uPolicy uError authData
, checkAuth :: authData -> m (Either uError usr) , checkAuth :: authData -> m (Either uError usr)
, subserver :: subserver
} }
newtype AuthProtectedSimple req e usr = AuthProtectedSimple { authHandler :: req -> IO (Either e usr) }
-- | Basic Authentication with respect to a specified @realm@ and a @lookup@ -- | Basic Authentication with respect to a specified @realm@ and a @lookup@
-- type to encapsulate authentication logic. -- type to encapsulate authentication logic.
data BasicAuth (realm :: Symbol) = BasicAuth { baUser :: ByteString data BasicAuth (realm :: Symbol) = BasicAuth { baUser :: ByteString

View file

@ -17,4 +17,5 @@ extra-deps:
- engine-io-wai-1.0.2 - engine-io-wai-1.0.2
- control-monad-omega-0.3.1 - control-monad-omega-0.3.1
- jwt-0.6.0 - jwt-0.6.0
- should-not-typecheck-2.0.1
resolver: nightly-2015-10-08 resolver: nightly-2015-10-08