From 0461c4642d669b6f0478767cfd9512cdde21e710 Mon Sep 17 00:00:00 2001 From: aaron levin Date: Wed, 17 Feb 2016 21:21:57 +0100 Subject: [PATCH] Add gen. authentication support to servant-server --- servant-server/servant-server.cabal | 1 + servant-server/src/Servant/Server.hs | 6 ++- servant-server/src/Servant/Server/Internal.hs | 23 +++++++++- .../src/Servant/Server/Internal/Auth.hs | 27 ++++++++++++ servant-server/test/Servant/ServerSpec.hs | 44 +++++++++++++++++-- 5 files changed, 95 insertions(+), 6 deletions(-) create mode 100644 servant-server/src/Servant/Server/Internal/Auth.hs diff --git a/servant-server/servant-server.cabal b/servant-server/servant-server.cabal index 2aa25cee..f15e7a45 100644 --- a/servant-server/servant-server.cabal +++ b/servant-server/servant-server.cabal @@ -38,6 +38,7 @@ library Servant.Server Servant.Server.Internal Servant.Server.Internal.Context + Servant.Server.Internal.Auth Servant.Server.Internal.BasicAuth Servant.Server.Internal.Enter Servant.Server.Internal.Router diff --git a/servant-server/src/Servant/Server.hs b/servant-server/src/Servant/Server.hs index 6b37297e..c88b1375 100644 --- a/servant-server/src/Servant/Server.hs +++ b/servant-server/src/Servant/Server.hs @@ -45,11 +45,15 @@ module Servant.Server , NamedContext(..) , descendIntoNamedContext - -- * Basic Authentication , BasicAuthCheck(BasicAuthCheck, unBasicAuthCheck) , BasicAuthResult(..) + -- * General Authentication + , AuthHandler(unAuthHandler) + , AuthServerData + , mkAuthHandler + -- * Default error type , ServantErr(..) -- ** 3XX diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index ea89b0a0..37955122 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -10,11 +10,13 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} #include "overlapping-compat.h" module Servant.Server.Internal ( module Servant.Server.Internal + , module Servant.Server.Internal.Auth , module Servant.Server.Internal.Context , module Servant.Server.Internal.BasicAuth , module Servant.Server.Internal.Router @@ -25,7 +27,7 @@ module Servant.Server.Internal #if !MIN_VERSION_base(4,8,0) import Control.Applicative ((<$>)) #endif -import Control.Monad.Trans.Except (ExceptT) +import Control.Monad.Trans.Except (ExceptT, runExceptT) import qualified Data.ByteString as B import qualified Data.ByteString.Char8 as BC8 import qualified Data.ByteString.Lazy as BL @@ -50,7 +52,7 @@ import Web.HttpApiData.Internal (parseHeaderMaybe, parseQueryParamMaybe, parseUrlPieceMaybe) -import Servant.API ((:<|>) (..), (:>), BasicAuth, Capture, +import Servant.API ((:<|>) (..), (:>), AuthProtect, BasicAuth, Capture, Verb, ReflectMethod(reflectMethod), IsSecure(..), Header, QueryFlag, QueryParam, QueryParams, @@ -64,6 +66,7 @@ import Servant.API.ContentTypes (AcceptHeader (..), import Servant.API.ResponseHeaders (GetHeaders, Headers, getHeaders, getResponse) +import Servant.Server.Internal.Auth import Servant.Server.Internal.Context import Servant.Server.Internal.BasicAuth import Servant.Server.Internal.Router @@ -482,6 +485,22 @@ pathIsEmpty = go . pathInfo ct_wildcard :: B.ByteString ct_wildcard = "*" <> "/" <> "*" -- Because CPP +-- * General Authentication + +instance ( HasServer api context + , HasContextEntry context (AuthHandler Request (AuthServerData (AuthProtect tag))) + ) + => HasServer (AuthProtect tag :> api) context where + + type ServerT (AuthProtect tag :> api) m = + AuthServerData (AuthProtect tag) -> ServerT api m + + route Proxy context subserver = WithRequest $ \ request -> + route (Proxy :: Proxy api) context (subserver `addAuthCheck` authCheck request) + where + authHandler = unAuthHandler (getContextEntry context) + authCheck = fmap (either FailFatal Route) . runExceptT . authHandler + -- * contexts instance (HasContextEntry context (NamedContext name subContext), HasServer subApi subContext) diff --git a/servant-server/src/Servant/Server/Internal/Auth.hs b/servant-server/src/Servant/Server/Internal/Auth.hs new file mode 100644 index 00000000..e9c69db8 --- /dev/null +++ b/servant-server/src/Servant/Server/Internal/Auth.hs @@ -0,0 +1,27 @@ +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TypeFamilies #-} + +module Servant.Server.Internal.Auth where + +import Control.Monad.Trans.Except (ExceptT) +import Data.Typeable (Typeable) +import GHC.Generics (Generic) + +import Servant.Server.Internal.ServantErr (ServantErr) + +-- * General Auth + +-- | Specify the type of data returned after we've authenticated a request. +-- quite often this is some `User` datatype. +type family AuthServerData a :: * + +-- | Handlers for AuthProtected resources +newtype AuthHandler r usr = AuthHandler + { unAuthHandler :: r -> ExceptT ServantErr IO usr } + deriving (Generic, Typeable) + +mkAuthHandler :: (r -> ExceptT ServantErr IO usr) -> AuthHandler r usr +mkAuthHandler = AuthHandler diff --git a/servant-server/test/Servant/ServerSpec.hs b/servant-server/test/Servant/ServerSpec.hs index 0524a11a..04e6f407 100644 --- a/servant-server/test/Servant/ServerSpec.hs +++ b/servant-server/test/Servant/ServerSpec.hs @@ -31,14 +31,15 @@ import Network.HTTP.Types (Status (..), hAccept, hContentType, methodHead, methodPatch, methodPost, methodPut, ok200, parseQuery) -import Network.Wai (Application, Request, pathInfo, +import Network.Wai (Application, Request, requestHeaders, pathInfo, queryString, rawQueryString, responseBuilder, responseLBS) import Network.Wai.Internal (Response (ResponseBuilder)) import Network.Wai.Test (defaultRequest, request, runSession, simpleBody, simpleHeaders, simpleStatus) -import Servant.API ((:<|>) (..), (:>), BasicAuth, BasicAuthData(BasicAuthData), +import Servant.API ((:<|>) (..), (:>), AuthProtect, + BasicAuth, BasicAuthData(BasicAuthData), Capture, Delete, Get, Header (..), Headers, HttpVersion, IsSecure (..), JSON, @@ -59,6 +60,9 @@ import Test.Hspec.Wai (get, liftIO, matchHeaders, import Servant.Server.Internal.BasicAuth (BasicAuthCheck(BasicAuthCheck), BasicAuthResult(Authorized,Unauthorized)) +import Servant.Server.Internal.Auth + (AuthHandler, AuthServerData, + mkAuthHandler) import Servant.Server.Internal.RoutingApplication (toApplication, RouteResult(..)) import Servant.Server.Internal.Router @@ -90,6 +94,7 @@ spec = do routerSpec miscCombinatorSpec basicAuthSpec + genAuthSpec ------------------------------------------------------------------------------ -- * verbSpec {{{ @@ -534,7 +539,7 @@ miscCombinatorSpec = with (return $ serve miscApi miscServ) $ -- }}} ------------------------------------------------------------------------------ --- * Authentication {{{ +-- * Basic Authentication {{{ ------------------------------------------------------------------------------ type BasicAuthAPI = BasicAuth "foo" () :> "basic" :> Get '[JSON] Animal @@ -564,6 +569,39 @@ basicAuthSpec = do it "returns 200 with the right password" $ do THW.request methodGet "/basic" [("Authorization","Basic c2VydmFudDpzZXJ2ZXI=")] "" `shouldRespondWith` 200 +-- }}} +------------------------------------------------------------------------------ +-- * General Authentication {{{ +------------------------------------------------------------------------------ + +type GenAuthAPI = AuthProtect "auth" :> "auth" :> Get '[JSON] Animal +authApi :: Proxy GenAuthAPI +authApi = Proxy +authServer :: Server GenAuthAPI +authServer = const (return tweety) + +type instance AuthServerData (AuthProtect "auth") = () + +genAuthContext :: Context '[ AuthHandler Request () ] +genAuthContext = + let authHandler = (\req -> + if elem ("Auth", "secret") (requestHeaders req) + then return () + else throwE err401 + ) + in mkAuthHandler authHandler :. EmptyContext + +genAuthSpec :: Spec +genAuthSpec = do + describe "Servant.API.Auth" $ do + with (return (serveWithContext authApi genAuthContext authServer)) $ do + + context "Custom Auth Protection" $ do + it "returns 401 when missing headers" $ do + get "/auth" `shouldRespondWith` 401 + it "returns 200 with the right header" $ do + THW.request methodGet "/auth" [("Auth","secret")] "" `shouldRespondWith` 200 + -- }}} ------------------------------------------------------------------------------ -- * Test data types {{{