diff --git a/servant-server/servant-server.cabal b/servant-server/servant-server.cabal index 8cc305a4..cc28819e 100644 --- a/servant-server/servant-server.cabal +++ b/servant-server/servant-server.cabal @@ -124,6 +124,7 @@ test-suite spec , wai , wai-extra , warp + , jwt test-suite doctests build-depends: base diff --git a/servant-server/src/Servant/Server/Internal/Authentication.hs b/servant-server/src/Servant/Server/Internal/Authentication.hs index 216403e4..7f4765e6 100644 --- a/servant-server/src/Servant/Server/Internal/Authentication.hs +++ b/servant-server/src/Servant/Server/Internal/Authentication.hs @@ -17,7 +17,7 @@ module Servant.Server.Internal.Authentication , jwtAuthStrict ) where -import Control.Monad (guard, (<=<)) +import Control.Monad (guard) import qualified Data.ByteString as B import Data.ByteString.Base64 (decodeLenient) #if !MIN_VERSION_base(4,8,0) @@ -30,7 +30,7 @@ import Data.String (fromString) import Data.Word8 (isSpace, toLower, _colon) import GHC.TypeLits (KnownSymbol, symbolVal) import Data.Text.Encoding (decodeUtf8) -import Data.Text (splitOn, Text) +import Data.Text (splitOn) import Network.HTTP.Types.Status (status401) import Network.Wai (Request, Response, requestHeaders, responseBuilder) @@ -38,8 +38,10 @@ import Servant.API.Authentication (AuthPolicy (Strict, Lax), AuthProtected, BasicAuth (BasicAuth), JWTAuth (..)) + JWTAuth) +import Servant.Server.Internal.RoutingApplication (RouteResult(FailFatal)) -import Web.JWT (JWT, UnverifiedJWT, VerifiedJWT, Secret, JSON) +import Web.JWT (JWT, VerifiedJWT, Secret) import qualified Web.JWT as JWT (decode, decodeAndVerifySignature, secret) -- | Class to represent the ability to extract authentication-related @@ -50,10 +52,10 @@ class AuthData a where -- | handlers to deal with authentication failures. data AuthHandlers authData = AuthHandlers { -- we couldn't find the right type of auth data (or any, for that matter) - onMissingAuthData :: IO Response + onMissingAuthData :: IO ServantError , -- we found the right type of auth data in the request but the check failed - onUnauthenticated :: authData -> IO Response + onUnauthenticated :: authData -> IO ServantError } -- | concrete type to provide when in 'Strict' mode. @@ -119,10 +121,9 @@ basicAuthLax = laxProtect instance AuthData JWTAuth where authData req = do - -- We might want to write a proper parser for this? but split works fine... hdr <- lookup "Authorization" . requestHeaders $ req ["Bearer", token] <- return . splitOn " " . decodeUtf8 $ hdr - JWT.decode token -- try decode it. otherwise it's not a proper token + _ <- JWT.decode token -- try decode it. otherwise it's not a proper token return . JWTAuth $ token diff --git a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs index 4b27c688..70cd62b3 100644 --- a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs +++ b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs @@ -1,4 +1,5 @@ {-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeOperators #-} @@ -19,6 +20,8 @@ import Network.Wai (Application, Request, Response, ResponseReceived, requestBody, strictRequestBody) +import Servant.API.Authentication (AuthPolicy(Strict,Lax)) +import Servant.Server.Internal.Authentication import Servant.Server.Internal.ServantErr type RoutingApplication = @@ -155,32 +158,60 @@ data Delayed :: * -> * where Delayed :: IO (RouteResult a) -> IO (RouteResult ()) -> IO (RouteResult b) - -> (a -> b -> RouteResult c) - -> Delayed c + -> IO (RouteResult c) + -> (a -> b -> c -> RouteResult d) + -> Delayed d instance Functor Delayed where - fmap f (Delayed a b c g) = Delayed a b c ((fmap.fmap.fmap) f g) + fmap f (Delayed a b c d g) = Delayed a b c d ((fmap.fmap.fmap.fmap) f g) -- | Add a capture to the end of the capture block. addCapture :: Delayed (a -> b) -> IO (RouteResult a) -> Delayed b -addCapture (Delayed captures method body server) new = - Delayed (combineRouteResults (,) captures new) method body (\ (x, v) y -> ($ v) <$> server x y) +addCapture (Delayed captures method auth body server) new = + Delayed (combineRouteResults (,) captures new) method auth body (\ (x, v) y z -> ($ v) <$> server x y z) -- | Add a method check to the end of the method block. addMethodCheck :: Delayed a -> IO (RouteResult ()) -> Delayed a -addMethodCheck (Delayed captures method body server) new = - Delayed captures (combineRouteResults const method new) body server +addMethodCheck (Delayed captures method auth body server) new = + Delayed captures (combineRouteResults const method new) auth body server + +addAuthStrictCheck :: Delayed (AuthProtected auth usr (usr -> a) 'Strict) + -> IO (RouteResult (Maybe usr)) + -> Delayed a +addAuthStrictCheck delayed@(Delayed captures method _ body server) new = + let newAuth = runDelayed delayed `bindRouteResults` \ authProtectionStrict -> new `bindRouteResults` \ mAuthData -> case mAuthData of + + Nothing -> do + -- we're in strict mode: don't let the request go + -- call the provided "on missing auth" handler + resp <- onMissingAuthData (authHandlers authProtectionStrict) + return $ FailFatal resp + + -- successfully pulled auth data out of the request + Just authData -> do + mUsr <- (checkAuthStrict authProtectionStrict) authData + case mUsr of + -- this user is not authenticated + Nothing -> do + resp <- onUnauthenticated (authHandlers authProtectionStrict) authData + return $ FailFatal resp + + -- this user is authenticated + Just usr -> + (return . Route . subServerStrict authProtectionStrict) usr + in Delayed captures method newAuth body server + -- | Add a body check to the end of the body block. addBodyCheck :: Delayed (a -> b) -> IO (RouteResult a) -> Delayed b -addBodyCheck (Delayed captures method body server) new = - Delayed captures method (combineRouteResults (,) body new) (\ x (y, v) -> ($ v) <$> server x y) +addBodyCheck (Delayed captures method auth body server) new = + Delayed captures method auth (combineRouteResults (,) body new) (\ x (y, v) z -> ($ v) <$> server x y z) -- | Add an accept header check to the end of the body block. -- The accept header check should occur after the body check, @@ -189,8 +220,8 @@ addBodyCheck (Delayed captures method body server) new = addAcceptCheck :: Delayed a -> IO (RouteResult ()) -> Delayed a -addAcceptCheck (Delayed captures method body server) new = - Delayed captures method (combineRouteResults const body new) server +addAcceptCheck (Delayed captures method auth body server) new = + Delayed captures method auth (combineRouteResults const body new) server -- | Many combinators extract information that is passed to -- the handler without the possibility of failure. In such a diff --git a/servant-server/test/Servant/ServerSpec.hs b/servant-server/test/Servant/ServerSpec.hs index 22c14175..1a917b10 100644 --- a/servant-server/test/Servant/ServerSpec.hs +++ b/servant-server/test/Servant/ServerSpec.hs @@ -67,6 +67,7 @@ import Servant.API.Authentication import Servant.Server.Internal.Authentication import Servant.Server (Server, serve, ServantErr(..), err404) import Servant.Server.Internal (RouteMismatch (..)) +import Web.JWT hiding (JSON) -- * test data types @@ -116,7 +117,8 @@ spec = do routerSpec responseHeadersSpec miscReqCombinatorsSpec - authRequiredSpec + basicAuthRequiredSpec + jwtAuthRequiredSpec type CaptureApi = Capture "legs" Integer :> Get '[JSON] Animal @@ -714,7 +716,7 @@ miscReqCombinatorsSpec = with (return $ serve miscApi miscServ) $ type AuthUser = ByteString type BasicAuthFooRealm = AuthProtect (BasicAuth "foo-realm") AuthUser 'Strict type BasicAuthBarRealm = AuthProtect (BasicAuth "bar-realm") AuthUser 'Strict -type AuthRequiredAPI = BasicAuthFooRealm :> "foo" :> Get '[JSON] Person +type BasicAuthRequiredAPI = BasicAuthFooRealm :> "foo" :> Get '[JSON] Person :<|> "bar" :> BasicAuthBarRealm :> Get '[JSON] Animal basicAuthFooCheck :: BasicAuth "foo-realm" -> IO (Maybe AuthUser) @@ -726,11 +728,11 @@ basicAuthBarCheck :: BasicAuth "bar-realm" -> IO (Maybe AuthUser) basicAuthBarCheck (BasicAuth usr pass) = if usr == "bar" && pass == "bar" then return (Just "bar") else return Nothing -authRequiredApi :: Proxy AuthRequiredAPI -authRequiredApi = Proxy +basicBasicAuthRequiredApi :: Proxy BasicAuthRequiredAPI +basicBasicAuthRequiredApi = Proxy -authRequiredServer :: Server AuthRequiredAPI -authRequiredServer = basicAuthStrict basicAuthFooCheck (const . return $ alice) +basicAuthRequiredServer :: Server BasicAuthRequiredAPI +basicAuthRequiredServer = basicAuthStrict basicAuthFooCheck (const . return $ alice) :<|> basicAuthStrict basicAuthBarCheck (const . return $ jerry) -- base64-encoded "servant:server" @@ -745,25 +747,25 @@ base64BarColonPassword = "YmFyOmJhcg==" base64UserColonPassword :: ByteString base64UserColonPassword = "dXNlcjpwYXNzd29yZA==" -authGet :: ByteString -> ByteString -> WaiSession SResponse -authGet path base64EncodedAuth = Test.Hspec.Wai.request methodGet path [("Authorization", "Basic " <> base64EncodedAuth)] "" +basicAuthGet :: ByteString -> ByteString -> WaiSession SResponse +basicAuthGet path base64EncodedAuth = Test.Hspec.Wai.request methodGet path [("Authorization", "Basic " <> base64EncodedAuth)] "" -authRequiredSpec :: Spec -authRequiredSpec = do +basicAuthRequiredSpec :: Spec +basicAuthRequiredSpec = do describe "Servant.API.Authentication" $ do - with (return $ serve authRequiredApi authRequiredServer) $ do + with (return $ serve basicBasicAuthRequiredApi basicAuthRequiredServer) $ do it "allows access with the correct username and password" $ do - response <- authGet "/foo" base64ServantColonServer + response <- basicAuthGet "/foo" base64ServantColonServer liftIO $ do decode' (simpleBody response) `shouldBe` Just alice - response <- authGet "/bar" base64BarColonPassword + response <- basicAuthGet "/bar" base64BarColonPassword liftIO $ do decode' (simpleBody response) `shouldBe` Just jerry it "rejects requests with the incorrect username and password" $ do - authGet "/foo" base64UserColonPassword `shouldRespondWith` 401 - authGet "/bar" base64UserColonPassword `shouldRespondWith` 401 + basicAuthGet "/foo" base64UserColonPassword `shouldRespondWith` 401 + basicAuthGet "/bar" base64UserColonPassword `shouldRespondWith` 401 it "does not respond to non-authenticated requests" $ do get "/foo" `shouldRespondWith` 401 @@ -774,3 +776,42 @@ authRequiredSpec = do bar401 <- get "/bar" WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"foo-realm\"" foo401) WaiSession (assertHeader "WWW-Authenticate" "Basic realm=\"bar-realm\"" bar401) + + + +type JWTAuthProtect = AuthProtect JWTAuth (JWT VerifiedJWT) 'Strict + +type JWTAuthRequiredAPI = JWTAuthProtect :> "foo" :> Get '[JSON] Person + + +jwtAuthRequiredApi :: Proxy JWTAuthRequiredAPI +jwtAuthRequiredApi = Proxy + +jwtAuthRequiredServer :: Server JWTAuthRequiredAPI +jwtAuthRequiredServer = jwtAuthStrict (secret "secret") (const . return $ alice) + +correctToken = "blah" +incorrectToken = "blah" + +jwtAuthGet :: ByteString -> ByteString -> WaiSession SResponse +jwtAuthGet path token = Test.Hspec.Wai.request methodGet path [("Authorization", "Bearer " <> token)] "" + +jwtAuthRequiredSpec :: Spec +jwtAuthRequiredSpec = do + describe "JWT Auth" $ do + with (return $ serve jwtAuthRequiredApi jwtAuthRequiredServer) $ do + it "allows access with the correct token" $ do + response <- jwtAuthGet "/foo" correctToken + liftIO $ do + decode' (simpleBody response) `shouldBe` Just alice + it "rejects requests with an incorrect token" $ do + jwtAuthGet "/foo" incorrectToken `shouldRespondWith` 401 + it "rejects requests without auth data" $ do + get "/foo" `shouldRespondWith` 401 + it "responds correctly to requests without auth data" $ do + a <- jwtAuthGet "/foo" incorrectToken + WaiSession (assertHeader "WWW-Authenticate" "Bearer error=\"invalid_token\"" a) + it "respond correctly to requests with incorrect auth data" $ do + a <- get "/foo" + WaiSession (assertHeader "WWW-Authenticate" "Bearer error=\"invalid_request\"" a) +