diff --git a/servant-client/servant-client.cabal b/servant-client/servant-client.cabal index 71cb2ee6..10626aeb 100644 --- a/servant-client/servant-client.cabal +++ b/servant-client/servant-client.cabal @@ -27,6 +27,7 @@ source-repository head library exposed-modules: Servant.Client + Servant.Common.Auth Servant.Common.BaseUrl Servant.Common.Req build-depends: @@ -34,6 +35,7 @@ library , aeson , attoparsec , bytestring + , base64-bytestring , exceptions , http-api-data >= 0.1 && < 0.3 , http-client @@ -68,6 +70,7 @@ test-suite spec , transformers-compat , aeson , bytestring + , base64-bytestring , deepseq , hspec == 2.* , http-client diff --git a/servant-client/src/Servant/Client.hs b/servant-client/src/Servant/Client.hs index 82779651..6ccc1777 100644 --- a/servant-client/src/Servant/Client.hs +++ b/servant-client/src/Servant/Client.hs @@ -15,8 +15,12 @@ -- querying functions for each endpoint just from the type representing your -- API. module Servant.Client - ( client + ( AuthClientData + , AuthenticateReq(..) + , BasicAuthData(..) + , client , HasClient(..) + , mkAuthenticateReq , ServantError(..) , module Servant.Common.BaseUrl ) where @@ -36,6 +40,7 @@ import Network.HTTP.Media import qualified Network.HTTP.Types as H import qualified Network.HTTP.Types.Header as HTTP import Servant.API +import Servant.Common.Auth import Servant.Common.BaseUrl import Servant.Common.Req @@ -423,6 +428,20 @@ instance HasClient subapi => type Client (WithNamedConfig name config subapi) = Client subapi clientWithRoute Proxy = clientWithRoute (Proxy :: Proxy subapi) +instance HasClient api => HasClient (BasicAuth realm :> api) where + type Client (BasicAuth realm :> api) = BasicAuthData -> Client api + + clientWithRoute Proxy req baseurl manager val = + clientWithRoute (Proxy :: Proxy api) (basicAuthReq val req) baseurl manager + +instance ( HasClient api + ) => HasClient (AuthProtect tag :> api) where + type Client (AuthProtect tag :> api) + = AuthenticateReq (AuthProtect tag) -> Client api + + clientWithRoute Proxy req baseurl manager (AuthenticateReq (val,func)) = + clientWithRoute (Proxy :: Proxy api) (func val req) baseurl manager + {- Note [Non-Empty Content Types] ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/servant-client/src/Servant/Common/Auth.hs b/servant-client/src/Servant/Common/Auth.hs new file mode 100644 index 00000000..a8502be2 --- /dev/null +++ b/servant-client/src/Servant/Common/Auth.hs @@ -0,0 +1,49 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE TypeFamilies #-} + +-- | Authentication for clients + +module Servant.Common.Auth ( + AuthenticateReq(AuthenticateReq, unAuthReq) + , AuthClientData + , BasicAuthData (BasicAuthData, username, password) + , basicAuthReq + , mkAuthenticateReq + ) where + +import Data.ByteString (ByteString) +import Data.ByteString.Base64 (encode) +import Data.Monoid ((<>)) +import Data.Text.Encoding (decodeUtf8) +import Servant.Common.Req (addHeader, Req) + + +-- | A simple datatype to hold data required to decorate a request +data BasicAuthData = BasicAuthData { username :: ByteString + , password :: ByteString + } + +-- | Authenticate a request using Basic Authentication +basicAuthReq :: BasicAuthData -> Req -> Req +basicAuthReq (BasicAuthData user pass) req = + let authText = decodeUtf8 ("Basic " <> encode (user <> ":" <> pass)) + in addHeader "Authorization" authText req + +-- | For a resource protected by authentication (e.g. AuthProtect), we need +-- to provide the client with some data used to add authentication data +-- to a request +type family AuthClientData a :: * + +-- | For better type inference and to avoid usage of a data family, we newtype +-- wrap the combination of some 'AuthClientData' and a function to add authentication +-- data to a request +newtype AuthenticateReq a = + AuthenticateReq { unAuthReq :: (AuthClientData a, AuthClientData a -> Req -> Req) } + +-- | Handy helper to avoid wrapping datatypes in tuples everywhere. +mkAuthenticateReq :: AuthClientData a + -> (AuthClientData a -> Req -> Req) + -> AuthenticateReq a +mkAuthenticateReq val func = AuthenticateReq (val, func) + diff --git a/servant-client/test/Servant/ClientSpec.hs b/servant-client/test/Servant/ClientSpec.hs index 4cb1ef4c..04b7e55b 100644 --- a/servant-client/test/Servant/ClientSpec.hs +++ b/servant-client/test/Servant/ClientSpec.hs @@ -13,6 +13,7 @@ {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -fcontext-stack=100 #-} {-# OPTIONS_GHC -fno-warn-orphans #-} @@ -28,7 +29,7 @@ import Control.Arrow (left) import Control.Concurrent (forkIO, killThread, ThreadId) import Control.Exception (bracket) import Control.Monad.Trans.Except (ExceptT, runExceptT, throwE) -import Data.Aeson +import Data.Aeson hiding ((.:)) import Data.Char (chr, isPrint) import Data.Foldable (forM_) import Data.Monoid hiding (getLast) @@ -41,7 +42,7 @@ import Network.HTTP.Media import Network.HTTP.Types (Status (..), badRequest400, methodGet, ok200, status400) import Network.Socket -import Network.Wai (Application, responseLBS) +import Network.Wai (Application, Request, requestHeaders, responseLBS) import Network.Wai.Handler.Warp import System.IO.Unsafe (unsafePerformIO) import Test.Hspec @@ -53,6 +54,7 @@ import Servant.API import Servant.API.Internal.Test.ComprehensiveAPI import Servant.Client import Servant.Server +import qualified Servant.Common.Req as SCR -- This declaration simply checks that all instances are in place. _ = client comprehensiveAPI @@ -62,6 +64,7 @@ spec = describe "Servant.Client" $ do sucessSpec failSpec wrappedApiSpec + authSpec -- * test data types @@ -111,9 +114,11 @@ type Api = Get '[JSON] (String, Maybe Int, Bool, [(String, [Rational])]) :<|> "headers" :> Get '[JSON] (Headers TestHeaders Bool) :<|> "deleteContentType" :> DeleteNoContent '[JSON] NoContent + api :: Proxy Api api = Proxy + server :: Application server = serve api EmptyConfig ( return alice @@ -148,6 +153,46 @@ failServer = serve failApi EmptyConfig ( :<|> (\_request respond -> respond $ responseLBS ok200 [("content-type", "fooooo")] "") ) +-- auth stuff +type AuthAPI = + BasicAuth "foo-realm" :> "private" :> "basic" :> Get '[JSON] Person + :<|> AuthProtect "auth-tag" :> "private" :> "auth" :> Get '[JSON] Person + +authAPI :: Proxy AuthAPI +authAPI = Proxy + +type instance AuthReturnType (BasicAuth "foo-realm") = () +type instance AuthReturnType (AuthProtect "auth-tag") = () +type instance AuthClientData (AuthProtect "auth-tag") = () + +basicAuthHandler :: BasicAuthCheck () +basicAuthHandler = + let check username password = + if username == "servant" && password == "server" + then return (Authorized ()) + else return Unauthorized + in BasicAuthCheck check + +authHandler :: AuthHandler Request () +authHandler = + let handler req = case lookup "AuthHeader" (requestHeaders req) of + Nothing -> throwE (err401 { errBody = "Missing auth header" + , errReasonPhrase = "denied!" + }) + Just _ -> return () + in mkAuthHandler handler + +serverConfig :: Config '[ BasicAuthCheck () + , AuthHandler Request () + ] +serverConfig = basicAuthHandler :. authHandler :. EmptyConfig + +authServer :: Application +authServer = serve authAPI serverConfig (const (return alice) :<|> const (return alice)) + +{- + -} + {-# NOINLINE manager #-} manager :: C.Manager manager = unsafePerformIO $ C.newManager C.defaultManagerSettings @@ -287,14 +332,41 @@ failSpec = beforeAll (startWaiApp failServer) $ afterAll endWaiApp $ do InvalidContentTypeHeader "fooooo" _ -> return () _ -> fail $ "expected InvalidContentTypeHeader, but got " <> show res +authSpec :: Spec +authSpec = beforeAll (startWaiApp authServer) $ afterAll endWaiApp $ do + context "Authentication works when requests are properly authenticated" $ do + + it "Authenticates a BasicAuth protected server appropriately" $ \(_,baseUrl) -> do + let (getBasic :<|> _) = client authAPI baseUrl manager + let authData = BasicAuthData "servant" "server" + (left show <$> runExceptT (getBasic authData)) `shouldReturn` Right alice + + it "Authenticates a AuthProtect protected server appropriately" $ \(_, baseUrl) -> do + let (_ :<|> getProtected) = client authAPI baseUrl manager + let authRequest = mkAuthenticateReq () (\_ req -> SCR.addHeader "AuthHeader" ("cool" :: String) req) + (left show <$> runExceptT (getProtected authRequest)) `shouldReturn` Right alice + + context "Authentication is rejected when requests are not authenticated properly" $ do + + it "Authenticates a BasicAuth protected server appropriately" $ \(_,baseUrl) -> do + let (getBasic :<|> _) = client authAPI baseUrl manager + let authData = BasicAuthData "not" "password" + Left FailureResponse{..} <- runExceptT (getBasic authData) + responseStatus `shouldBe` Status 403 "Forbidden" + + it "Authenticates a AuthProtect protected server appropriately" $ \(_, baseUrl) -> do + let (_ :<|> getProtected) = client authAPI baseUrl manager + let authRequest = mkAuthenticateReq () (\_ req -> SCR.addHeader "Wrong" ("header" :: String) req) + Left FailureResponse{..} <- runExceptT (getProtected authRequest) + responseStatus `shouldBe` (Status 401 "denied") + +-- * utils + data WrappedApi where WrappedApi :: (HasServer (api :: *) '[], Server api ~ ExceptT ServantErr IO a, HasClient api, Client api ~ ExceptT ServantError IO ()) => Proxy api -> WrappedApi - --- * utils - startWaiApp :: Application -> IO (ThreadId, BaseUrl) startWaiApp app = do (port, socket) <- openTestSocket diff --git a/servant-docs/src/Servant/Docs/Internal.hs b/servant-docs/src/Servant/Docs/Internal.hs index 70f8954c..32674f32 100644 --- a/servant-docs/src/Servant/Docs/Internal.hs +++ b/servant-docs/src/Servant/Docs/Internal.hs @@ -22,7 +22,7 @@ module Servant.Docs.Internal where import Control.Applicative import Control.Arrow (second) -import Control.Lens (makeLenses, over, traversed, (%~), +import Control.Lens (makeLenses, mapped, over, traversed, view, (%~), (&), (.~), (<>~), (^.), (|>)) import qualified Control.Monad.Omega as Omega import Data.ByteString.Conversion (ToByteString, toByteString) @@ -140,6 +140,12 @@ data DocIntro = DocIntro , _introBody :: [String] -- ^ Each String is a paragraph. } deriving (Eq, Show) +-- | A type to represent Authentication information about an endpoint. +data DocAuthentication = DocAuthentication + { _authIntro :: String + , _authDataRequired :: String + } deriving (Eq, Ord, Show) + instance Ord DocIntro where compare = comparing _introTitle @@ -230,7 +236,8 @@ defResponse = Response -- You can tweak an 'Action' (like the default 'defAction') with these lenses -- to transform an action and add some information to it. data Action = Action - { _captures :: [DocCapture] -- type collected + user supplied info + { _authInfo :: [DocAuthentication] -- user supplied info + , _captures :: [DocCapture] -- type collected + user supplied info , _headers :: [Text] -- type collected , _params :: [DocQueryParam] -- type collected + user supplied info , _notes :: [DocNote] -- user supplied @@ -247,8 +254,8 @@ data Action = Action -- 'combineAction' to mush two together taking the response, body and content -- types from the very left. combineAction :: Action -> Action -> Action -Action c h p n m ts body resp `combineAction` Action c' h' p' n' m' _ _ _ = - Action (c <> c') (h <> h') (p <> p') (n <> n') (m <> m') ts body resp +Action a c h p n m ts body resp `combineAction` Action a' c' h' p' n' m' _ _ _ = + Action (a <> a') (c <> c') (h <> h') (p <> p') (n <> n') (m <> m') ts body resp -- Default 'Action'. Has no 'captures', no GET 'params', expects -- no request body ('rqbody') and the typical response is 'defResponse'. @@ -268,6 +275,7 @@ defAction = [] [] [] + [] defResponse -- | Create an API that's comprised of a single endpoint. @@ -277,6 +285,7 @@ single :: Endpoint -> Action -> API single e a = API mempty (HM.singleton e a) -- gimme some lenses +makeLenses ''DocAuthentication makeLenses ''DocOptions makeLenses ''API makeLenses ''Endpoint @@ -454,7 +463,7 @@ instance AllHeaderSamples '[] where instance (ToByteString l, AllHeaderSamples ls, ToSample l, KnownSymbol h) => AllHeaderSamples (Header h l ': ls) where - allHeaderToSample _ = (mkHeader (toSample (Proxy :: Proxy l))) : + allHeaderToSample _ = mkHeader (toSample (Proxy :: Proxy l)) : allHeaderToSample (Proxy :: Proxy ls) where headerName = CI.mk . cs $ symbolVal (Proxy :: Proxy h) mkHeader (Just x) = (headerName, cs $ toByteString x) @@ -504,6 +513,10 @@ class ToParam t where class ToCapture c where toCapture :: Proxy c -> DocCapture +-- | The class that helps us get documentation for authenticated endpoints +class ToAuthInfo a where + toAuthInfo :: Proxy a -> DocAuthentication + -- | Generate documentation in Markdown format for -- the given 'API'. markdown :: API -> String @@ -516,6 +529,7 @@ markdown api = unlines $ str : "" : notesStr (action ^. notes) ++ + authStr (action ^. authInfo) ++ capturesStr (action ^. captures) ++ headersStr (action ^. headers) ++ paramsStr (action ^. params) ++ @@ -548,6 +562,20 @@ markdown api = unlines $ "" : [] + + authStr :: [DocAuthentication] -> [String] + authStr auths = + let authIntros = mapped %~ view authIntro $ auths + clientInfos = mapped %~ view authDataRequired $ auths + in "#### Authentication": + "": + unlines authIntros : + "": + "Clients must supply the following data" : + unlines clientInfos : + "" : + [] + capturesStr :: [DocCapture] -> [String] capturesStr [] = [] capturesStr l = @@ -797,6 +825,20 @@ instance HasDocs sublayout => HasDocs (Vault :> sublayout) where instance HasDocs sublayout => HasDocs (WithNamedConfig name config sublayout) where docsFor Proxy = docsFor (Proxy :: Proxy sublayout) +instance (ToAuthInfo (BasicAuth realm), HasDocs sublayout) => HasDocs (BasicAuth realm :> sublayout) where + docsFor Proxy (endpoint, action) = + docsFor (Proxy :: Proxy sublayout) (endpoint, action') + where + authProxy = Proxy :: Proxy (BasicAuth realm) + action' = over authInfo (|> toAuthInfo authProxy) action + +instance (ToAuthInfo (AuthProtect tag), HasDocs sublayout) => HasDocs (AuthProtect tag :> sublayout) where + docsFor Proxy (endpoint, action) = + docsFor (Proxy :: Proxy sublayout) (endpoint, action') + where + authProxy = Proxy :: Proxy (AuthProtect tag) + action' = over authInfo (|> toAuthInfo authProxy) action + -- ToSample instances for simple types instance ToSample () instance ToSample Bool diff --git a/servant-examples/auth-combinator/auth-combinator.hs b/servant-examples/auth-combinator/auth-combinator.hs index f2cebb4f..9773be83 100644 --- a/servant-examples/auth-combinator/auth-combinator.hs +++ b/servant-examples/auth-combinator/auth-combinator.hs @@ -9,56 +9,39 @@ {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -import Data.Aeson -import Data.ByteString (ByteString) -import Data.IORef -import Data.Text (Text) +import Control.Monad.Trans.Except (ExceptT, throwE) +import Data.Aeson hiding ((.:)) +import Data.ByteString (ByteString) +import Data.Monoid ((<>)) +import Data.Text (Text) import GHC.Generics import Network.Wai import Network.Wai.Handler.Warp import Servant -import Servant.Server.Internal --- Pretty much stolen/adapted from --- https://github.com/haskell-servant/HaskellSGMeetup2015/blob/master/examples/authentication-combinator/AuthenticationCombinator.hs +-- | A user type that we "fetch from the database" after +-- performing authentication +newtype User = User { unUser :: Text } -type DBConnection = IORef [ByteString] -type DBLookup = DBConnection -> ByteString -> IO Bool -initDB :: IO DBConnection -initDB = newIORef ["good password"] +-- | A method that, when given a password, will return a User. +-- This is our bespoke (and bad) authentication logic. +lookupUser :: ByteString -> ExceptT ServantErr IO User +lookupUser cookie = + if cookie == "good password" + then return (User "user") + else throwE (err403 { errBody = "Invalid Cookie" }) -isGoodCookie :: DBLookup -isGoodCookie ref password = do - allowed <- readIORef ref - return (password `elem` allowed) - -data AuthProtected - -instance (HasConfigEntry config DBConnection, HasServer rest config) - => HasServer (AuthProtected :> rest) config where - - type ServerT (AuthProtected :> rest) m = ServerT rest m - - route Proxy config subserver = WithRequest $ \ request -> - route (Proxy :: Proxy rest) config $ addAcceptCheck subserver $ cookieCheck request - where - cookieCheck req = case lookup "Cookie" (requestHeaders req) of - Nothing -> return $ FailFatal err401 { errBody = "Missing auth header" } - Just v -> do - let dbConnection = getConfigEntry config - authGranted <- isGoodCookie dbConnection v - if authGranted - then return $ Route () - else return $ FailFatal err403 { errBody = "Invalid cookie" } - -type PrivateAPI = Get '[JSON] [PrivateData] - -type PublicAPI = Get '[JSON] [PublicData] - -type API = "private" :> AuthProtected :> PrivateAPI - :<|> PublicAPI +-- | The auth handler wraps a function from Request -> ExceptT ServantErr IO User +-- we look for a Cookie and pass the value of the cookie to `lookupUser`. +authHandler :: AuthHandler Request User +authHandler = + let handler req = case lookup "Cookie" (requestHeaders req) of + Nothing -> throwE (err401 { errBody = "Missing auth header" }) + Just cookie -> lookupUser cookie + in mkAuthHandler handler +-- | Data types that will be returned from various api endpoints newtype PrivateData = PrivateData { ssshhh :: Text } deriving (Eq, Show, Generic) @@ -69,28 +52,54 @@ newtype PublicData = PublicData { somedata :: Text } instance ToJSON PublicData +-- | Our private API that we want to be auth-protected. +type PrivateAPI = Get '[JSON] [PrivateData] + +-- | Our public API that doesn't have any protection +type PublicAPI = Get '[JSON] [PublicData] + +-- | Our API, with auth-protection +type API = "private" :> AuthProtect "cookie-auth" :> PrivateAPI + :<|> "public" :> PublicAPI + +-- | A value holding our type-level API api :: Proxy API api = Proxy -server :: Server API -server = return prvdata :<|> return pubdata +-- | We need to specify the data returned after authentication +type instance AuthReturnType (AuthProtect "cookie-auth") = User - where prvdata = [PrivateData "this is a secret"] +-- | The configuration that will be made available to request handlers. We supply the +-- "cookie-auth"-tagged request handler defined above, so that the 'HasServer' instance +-- of 'AuthProtect' can extract the handler and run it on the request. +serverConfig :: Config (AuthHandler Request User ': '[]) +serverConfig = authHandler :. EmptyConfig + +-- | Our API, where we provide all the author-supplied handlers for each end point. +-- note that 'prvdata' is a function that takes 'User' as an argument. We dont' worry +-- about the authentication instrumentation here, that is taken care of by supplying +-- configuration +server :: Server API +server = prvdata :<|> return pubdata + + where prvdata (User name) = return [PrivateData ("this is a secret: " <> name)] pubdata = [PublicData "this is a public piece of data"] +-- | run our server main :: IO () -main = do - dbConnection <- initDB - let config = dbConnection :. EmptyConfig - run 8080 (serve api config server) +main = run 8080 (serve api serverConfig server) -{- Sample session: -$ curl http://localhost:8080/ +{- Sample Session: + +$ curl -XGET localhost:8080/private +Missing auth header + +$ curl -XGET localhost:8080/private -H "Cookie: good password" +[{"ssshhh":"this is a secret: user"}] + +$ curl -XGET localhost:8080/private -H "Cookie: bad password" +Invalid Cookie + +$ curl -XGET localhost:8080/public [{"somedata":"this is a public piece of data"}] -$ curl http://localhost:8080/private -Missing auth header. -$ curl -H "Cookie: good password" http://localhost:8080/private -[{"ssshhh":"this is a secret"}] -$ curl -H "Cookie: bad password" http://localhost:8080/private -Invalid cookie. -} diff --git a/servant-examples/basic-auth/basic-auth.hs b/servant-examples/basic-auth/basic-auth.hs new file mode 100644 index 00000000..c409f6ca --- /dev/null +++ b/servant-examples/basic-auth/basic-auth.hs @@ -0,0 +1,108 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +module Main where + +import Data.Aeson (ToJSON) +import Data.Proxy (Proxy (Proxy)) +import Data.Text (Text) +import GHC.Generics (Generic) +import Network.Wai.Handler.Warp (run) +import Servant.API ((:<|>) ((:<|>)), (:>), BasicAuth, + Get, JSON) +import Servant.Server (AuthReturnType, BasicAuthResult (Authorized, Unauthorized), Config ((:.), EmptyConfig), + Server, serve, BasicAuthCheck(BasicAuthCheck)) + +-- | let's define some types that our API returns. + +-- | private data that needs protection +newtype PrivateData = PrivateData { ssshhh :: Text } + deriving (Eq, Show, Generic) + +instance ToJSON PrivateData + +-- | public data that anyone can use. +newtype PublicData = PublicData { somedata :: Text } + deriving (Eq, Show, Generic) + +instance ToJSON PublicData + +-- | A user we'll grab from the database when we authenticate someone +newtype User = User { userName :: Text } + deriving (Eq, Show) + +-- | a type to wrap our public api +type PublicAPI = Get '[JSON] [PublicData] + +-- | a type to wrap our private api +type PrivateAPI = Get '[JSON] PrivateData + +-- | our API +type API = "public" :> PublicAPI + :<|> "private" :> BasicAuth "foo-realm" :> PrivateAPI + +-- | a value holding a proxy of our API type +api :: Proxy API +api = Proxy + +-- | a value holding a proxy of our basic auth realm. +authRealm :: Proxy "foo-realm" +authRealm = Proxy + +-- | Specify the data type returned after performing basic authentication +type instance AuthReturnType (BasicAuth "foo-realm") = User + +-- | 'BasicAuthCheck' holds the handler we'll use to verify a username and password. +authCheck :: BasicAuthCheck User +authCheck = + let check username password = + if username == "servant" && password == "server" + then return (Authorized (User "servant")) + else return Unauthorized + in BasicAuthCheck check + +-- | We need to supply our handlers with the right configuration. In this case, +-- Basic Authentication requires a Config Entry with the 'BasicAuthCheck' value +-- tagged with "foo-tag" This config is then supplied to 'server' and threaded +-- to the BasicAuth HasServer handlers. +serverConfig :: Config (BasicAuthCheck User ': '[]) +serverConfig = authCheck :. EmptyConfig + +-- | an implementation of our server. Here is where we pass all the handlers to our endpoints. +-- In particular, for the BasicAuth protected handler, we need to supply a function +-- that takes 'User' as an argument. +server :: Server API +server = + let publicAPIHandler = return [PublicData "foo", PublicData "bar"] + privateAPIHandler (user :: User) = return (PrivateData (userName user)) + in publicAPIHandler :<|> privateAPIHandler + +-- | hello, server! +main :: IO () +main = run 8080 (serve api serverConfig server) + +{- Sample session + +$ curl -XGET localhost:8080/public +[{"somedata":"foo"},{"somedata":"bar"} + +$ curl -iXGET localhost:8080/private +HTTP/1.1 401 Unauthorized +transfer-encoding: chunked +Date: Thu, 07 Jan 2016 22:36:38 GMT +Server: Warp/3.1.8 +WWW-Authenticate: Basic realm="foo-realm" + +$ curl -iXGET localhost:8080/private -H "Authorization: Basic c2VydmFudDpzZXJ2ZXI=" +HTTP/1.1 200 OK +transfer-encoding: chunked +Date: Thu, 07 Jan 2016 22:37:58 GMT +Server: Warp/3.1.8 +Content-Type: application/json + +{"ssshhh":"servant"} +-} diff --git a/servant-examples/servant-examples.cabal b/servant-examples/servant-examples.cabal index d62c01c7..d00ce302 100644 --- a/servant-examples/servant-examples.cabal +++ b/servant-examples/servant-examples.cabal @@ -100,11 +100,28 @@ executable auth-combinator , servant == 0.5.* , servant-server == 0.5.* , text + , transformers , wai , warp hs-source-dirs: auth-combinator default-language: Haskell2010 +executable basic-auth + main-is: basic-auth.hs + ghc-options: -Wall -fno-warn-unused-binds -fno-warn-name-shadowing + build-depends: + aeson >= 0.8 + , base >= 4.7 && < 5 + , bytestring + , http-types + , servant == 0.5.* + , servant-server == 0.5.* + , text + , wai + , warp + hs-source-dirs: basic-auth + default-language: Haskell2010 + executable socket-io-chat main-is: socket-io-chat.hs ghc-options: -Wall -fno-warn-unused-binds -fno-warn-name-shadowing diff --git a/servant-server/servant-server.cabal b/servant-server/servant-server.cabal index f6ed6319..267094f2 100644 --- a/servant-server/servant-server.cabal +++ b/servant-server/servant-server.cabal @@ -37,6 +37,7 @@ library Servant Servant.Server Servant.Server.Internal + Servant.Server.Internal.Auth Servant.Server.Internal.Config Servant.Server.Internal.Enter Servant.Server.Internal.Router @@ -47,6 +48,7 @@ library base >= 4.7 && < 5 , aeson >= 0.7 && < 0.11 , attoparsec >= 0.12 && < 0.14 + , base64-bytestring == 1.0.* , bytestring >= 0.10 && < 0.11 , containers >= 0.5 && < 0.6 , http-api-data >= 0.1 && < 0.3 @@ -67,6 +69,7 @@ library , wai >= 3.0 && < 3.3 , wai-app-static >= 3.0 && < 3.2 , warp >= 3.0 && < 3.3 + , word8 == 0.1.* hs-source-dirs: src default-language: Haskell2010 diff --git a/servant-server/src/Servant/Server.hs b/servant-server/src/Servant/Server.hs index ea78a969..340628be 100644 --- a/servant-server/src/Servant/Server.hs +++ b/servant-server/src/Servant/Server.hs @@ -43,6 +43,15 @@ module Servant.Server , NamedConfig(..) , descendIntoNamedConfig + -- * General Authentication + , AuthHandler(unAuthHandler) + , AuthReturnType + , mkAuthHandler + + -- * Basic Authentication + , BasicAuthCheck(BasicAuthCheck, unBasicAuthCheck) + , BasicAuthResult(..) + -- * Default error type , ServantErr(..) -- ** 3XX @@ -117,7 +126,7 @@ serve :: (HasServer layout config) => Proxy layout -> Config config -> Server layout -> Application serve p config server = toApplication (runRouter (route p config d)) where - d = Delayed r r r (\ _ _ -> Route server) + d = Delayed r r r r (\ _ _ _ -> Route server) r = return (Route ()) diff --git a/servant-server/src/Servant/Server/Internal.hs b/servant-server/src/Servant/Server/Internal.hs index 1b2c19a2..064511d3 100644 --- a/servant-server/src/Servant/Server/Internal.hs +++ b/servant-server/src/Servant/Server/Internal.hs @@ -10,11 +10,13 @@ {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} #include "overlapping-compat.h" module Servant.Server.Internal ( module Servant.Server.Internal + , module Servant.Server.Internal.Auth , module Servant.Server.Internal.Config , module Servant.Server.Internal.Router , module Servant.Server.Internal.RoutingApplication @@ -24,14 +26,16 @@ module Servant.Server.Internal #if !MIN_VERSION_base(4,8,0) import Control.Applicative ((<$>)) #endif -import Control.Monad.Trans.Except (ExceptT) -import qualified Data.ByteString as B -import qualified Data.ByteString.Lazy as BL -import qualified Data.Map as M -import Data.Maybe (fromMaybe, mapMaybe) -import Data.String (fromString) -import Data.String.Conversions (cs, (<>)) -import Data.Text (Text) +import Control.Monad.Trans.Except (ExceptT, runExceptT) +import qualified Data.ByteString as B +import qualified Data.ByteString.Char8 as BC8 +import qualified Data.ByteString.Lazy as BL +import qualified Data.Map as M +import Data.Maybe (fromMaybe, + mapMaybe) +import Data.String (fromString) +import Data.String.Conversions (cs, (<>)) +import Data.Text (Text) import Data.Typeable import GHC.TypeLits (KnownNat, KnownSymbol, natVal, symbolVal) @@ -48,7 +52,7 @@ import Web.HttpApiData.Internal (parseHeaderMaybe, parseQueryParamMaybe, parseUrlPieceMaybe) -import Servant.API ((:<|>) (..), (:>), Capture, +import Servant.API ((:<|>) (..), (:>), AuthProtect, BasicAuth, Capture, Verb, ReflectMethod(reflectMethod), IsSecure(..), Header, QueryFlag, QueryParam, QueryParams, @@ -62,6 +66,7 @@ import Servant.API.ContentTypes (AcceptHeader (..), import Servant.API.ResponseHeaders (GetHeaders, Headers, getHeaders, getResponse) +import Servant.Server.Internal.Auth import Servant.Server.Internal.Config import Servant.Server.Internal.Router import Servant.Server.Internal.RoutingApplication @@ -450,6 +455,28 @@ instance HasServer api config => HasServer (HttpVersion :> api) config where route Proxy config subserver = WithRequest $ \req -> route (Proxy :: Proxy api) config (passToServer subserver $ httpVersion req) +-- | Basic Authentication +instance (KnownSymbol realm, HasServer api config, HasConfigEntry config (BasicAuthCheck (AuthReturnType (BasicAuth realm)))) + => HasServer (BasicAuth realm :> api) config where + type ServerT (BasicAuth realm :> api) m = AuthReturnType (BasicAuth realm) -> ServerT api m + + route Proxy config subserver = WithRequest $ \ request -> + route (Proxy :: Proxy api) config (subserver `addAuthCheck` authCheck request) + where + realm = BC8.pack $ symbolVal (Proxy :: Proxy realm) + basicAuthConfig = getConfigEntry config + authCheck req = runBasicAuth req realm basicAuthConfig + +-- | General Authentication +instance (HasServer api config, HasConfigEntry config (AuthHandler Request (AuthReturnType (AuthProtect tag)))) => HasServer (AuthProtect tag :> api) config where + type ServerT (AuthProtect tag :> api) m = AuthReturnType (AuthProtect tag) -> ServerT api m + + route Proxy config subserver = WithRequest $ \ request -> + route (Proxy :: Proxy api) config (subserver `addAuthCheck` authCheck request) + where + authHandler = unAuthHandler (getConfigEntry config) + authCheck = fmap (either FailFatal Route) . runExceptT . authHandler + pathIsEmpty :: Request -> Bool pathIsEmpty = go . pathInfo where go [] = True diff --git a/servant-server/src/Servant/Server/Internal/Auth.hs b/servant-server/src/Servant/Server/Internal/Auth.hs new file mode 100644 index 00000000..6e15c7a5 --- /dev/null +++ b/servant-server/src/Servant/Server/Internal/Auth.hs @@ -0,0 +1,77 @@ +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TypeFamilies #-} + +module Servant.Server.Internal.Auth where + +import Control.Monad (guard) +import Control.Monad.Trans.Except (ExceptT) +import qualified Data.ByteString as BS +import Data.ByteString.Base64 (decodeLenient) +import Data.Monoid ((<>)) +import Data.Typeable (Typeable) +import Data.Word8 (isSpace, toLower, _colon) +import GHC.Generics +import Network.HTTP.Types (Header) +import Network.Wai (Request, requestHeaders) + +import Servant.Server.Internal.RoutingApplication +import Servant.Server.Internal.ServantErr + +-- * General Auth + +-- | Specify the type of data returned after we've authenticated a request. +-- quite often this is some `User` datatype. +type family AuthReturnType a :: * + +-- | Handlers for AuthProtected resources +newtype AuthHandler r usr = AuthHandler + { unAuthHandler :: r -> ExceptT ServantErr IO usr } + deriving (Generic, Typeable) + +mkAuthHandler :: (r -> ExceptT ServantErr IO usr) -> AuthHandler r usr +mkAuthHandler = AuthHandler + +-- | The result of authentication/authorization +data BasicAuthResult usr + = Unauthorized + | BadPassword + | NoSuchUser + | Authorized usr + deriving (Eq, Show, Read, Generic, Typeable, Functor) + +-- * Basic Auth + +newtype BasicAuthCheck usr = BasicAuthCheck + { unBasicAuthCheck :: BS.ByteString -- Username + -> BS.ByteString -- Password + -> IO (BasicAuthResult usr) + } + deriving (Generic, Typeable, Functor) + +mkBAChallengerHdr :: BS.ByteString -> Header +mkBAChallengerHdr realm = ("WWW-Authenticate", "Basic realm=\"" <> realm <> "\"") + +-- | Find and decode an 'Authorization' header from the request as Basic Auth +decodeBAHdr :: Request -> Maybe (BS.ByteString, BS.ByteString) +decodeBAHdr req = do + ah <- lookup "Authorization" $ requestHeaders req + let (b, rest) = BS.break isSpace ah + guard (BS.map toLower b == "basic") + let decoded = decodeLenient (BS.dropWhile isSpace rest) + let (username, passWithColonAtHead) = BS.break (== _colon) decoded + (_, password) <- BS.uncons passWithColonAtHead + return (username, password) + +runBasicAuth :: Request -> BS.ByteString -> BasicAuthCheck usr -> IO (RouteResult usr) +runBasicAuth req realm (BasicAuthCheck ba) = + case decodeBAHdr req of + Nothing -> plzAuthenticate + Just e -> uncurry ba e >>= \res -> case res of + BadPassword -> plzAuthenticate + NoSuchUser -> plzAuthenticate + Unauthorized -> return $ Fail err403 + Authorized usr -> return $ Route usr + where plzAuthenticate = return $ Fail err401 { errHeaders = [mkBAChallengerHdr realm] } diff --git a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs index 3112c640..33b3cfbd 100644 --- a/servant-server/src/Servant/Server/Internal/RoutingApplication.hs +++ b/servant-server/src/Servant/Server/Internal/RoutingApplication.hs @@ -1,9 +1,11 @@ {-# LANGUAGE CPP #-} {-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE StandaloneDeriving #-} module Servant.Server.Internal.RoutingApplication where @@ -84,6 +86,7 @@ toApplication ra request respond = do -- static routes (can cause 404) -- delayed captures (can cause 404) -- methods (can cause 405) +-- authentication and authorization (can cause 401, 403) -- delayed body (can cause 415, 400) -- accept header (can cause 406) -- @@ -151,36 +154,71 @@ toApplication ra request respond = do -- The accept header check can be performed as the final -- computation in this block. It can cause a 406. -- -data Delayed :: * -> * where - Delayed :: IO (RouteResult a) - -> IO (RouteResult ()) - -> IO (RouteResult b) - -> (a -> b -> RouteResult c) - -> Delayed c +data Delayed c = forall captures auth body. Delayed + { capturesD :: IO (RouteResult captures) + , methodD :: IO (RouteResult ()) + , authD :: IO (RouteResult auth) + , bodyD :: IO (RouteResult body) + , serverD :: (captures -> auth -> body -> RouteResult c) + } instance Functor Delayed where - fmap f (Delayed a b c g) = Delayed a b c ((fmap . fmap . fmap) f g) + fmap f Delayed{..} + = Delayed { capturesD = capturesD + , methodD = methodD + , authD = authD + , bodyD = bodyD + , serverD = (fmap.fmap.fmap.fmap) f serverD + } -- Note [Existential Record Update] -- | Add a capture to the end of the capture block. addCapture :: Delayed (a -> b) -> IO (RouteResult a) -> Delayed b -addCapture (Delayed captures method body server) new = - Delayed (combineRouteResults (,) captures new) method body (\ (x, v) y -> ($ v) <$> server x y) +addCapture Delayed{..} new + = Delayed { capturesD = combineRouteResults (,) capturesD new + , methodD = methodD + , authD = authD + , bodyD = bodyD + , serverD = \ (x, v) y z -> ($ v) <$> serverD x y z + } -- Note [Existential Record Update] -- | Add a method check to the end of the method block. addMethodCheck :: Delayed a -> IO (RouteResult ()) -> Delayed a -addMethodCheck (Delayed captures method body server) new = - Delayed captures (combineRouteResults const method new) body server +addMethodCheck Delayed{..} new + = Delayed { capturesD = capturesD + , methodD = combineRouteResults const methodD new + , authD = authD + , bodyD = bodyD + , serverD = serverD + } -- Note [Existential Record Update] + +-- | Add an auth check to the end of the auth block. +addAuthCheck :: Delayed (a -> b) + -> IO (RouteResult a) + -> Delayed b +addAuthCheck Delayed{..} new + = Delayed { capturesD = capturesD + , methodD = methodD + , authD = combineRouteResults (,) authD new + , bodyD = bodyD + , serverD = \ x (y, v) z -> ($ v) <$> serverD x y z + } -- Note [Existential Record Update] -- | Add a body check to the end of the body block. addBodyCheck :: Delayed (a -> b) -> IO (RouteResult a) -> Delayed b -addBodyCheck (Delayed captures method body server) new = - Delayed captures method (combineRouteResults (,) body new) (\ x (y, v) -> ($ v) <$> server x y) +addBodyCheck Delayed{..} new + = Delayed { capturesD = capturesD + , methodD = methodD + , authD = authD + , bodyD = combineRouteResults (,) bodyD new + , serverD = \ x y (z, v) -> ($ v) <$> serverD x y z + } -- Note [Existential Record Update] + -- | Add an accept header check to the end of the body block. -- The accept header check should occur after the body check, @@ -189,8 +227,13 @@ addBodyCheck (Delayed captures method body server) new = addAcceptCheck :: Delayed a -> IO (RouteResult ()) -> Delayed a -addAcceptCheck (Delayed captures method body server) new = - Delayed captures method (combineRouteResults const body new) server +addAcceptCheck Delayed{..} new + = Delayed { capturesD = capturesD + , methodD = methodD + , authD = authD + , bodyD = combineRouteResults const bodyD new + , serverD = serverD + } -- Note [Existential Record Update] -- | Many combinators extract information that is passed to -- the handler without the possibility of failure. In such a @@ -222,13 +265,17 @@ combineRouteResults f m1 m2 = -- | Run a delayed server. Performs all scheduled operations -- in order, and passes the results from the capture and body -- blocks on to the actual handler. +-- +-- This should only be called once per request; otherwise the guarantees about +-- effect and HTTP error ordering break down. runDelayed :: Delayed a -> IO (RouteResult a) -runDelayed (Delayed captures method body server) = - captures `bindRouteResults` \ c -> - method `bindRouteResults` \ _ -> - body `bindRouteResults` \ b -> - return (server c b) +runDelayed Delayed{..} = + capturesD `bindRouteResults` \ c -> + methodD `bindRouteResults` \ _ -> + authD `bindRouteResults` \ a -> + bodyD `bindRouteResults` \ b -> + return (serverD c a b) -- | Runs a delayed server and the resulting action. -- Takes a continuation that lets us send a response. @@ -247,3 +294,11 @@ runAction action respond k = runDelayed action >>= go >>= respond case e of Left err -> return . Route $ responseServantErr err Right x -> return $! k x + + +{- Note [Existential Record Update] +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Due to GHC issue , we cannot +do the more succint thing - just update the records we actually change. +-} diff --git a/servant-server/test/Servant/ServerSpec.hs b/servant-server/test/Servant/ServerSpec.hs index 04461566..9729428b 100644 --- a/servant-server/test/Servant/ServerSpec.hs +++ b/servant-server/test/Servant/ServerSpec.hs @@ -3,7 +3,6 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -31,14 +30,14 @@ import Network.HTTP.Types (Status (..), hAccept, hContentType, methodHead, methodPatch, methodPost, methodPut, ok200, parseQuery) -import Network.Wai (Application, Request, pathInfo, +import Network.Wai (Application, Request, requestHeaders, pathInfo, queryString, rawQueryString, responseBuilder, responseLBS) -import Network.Wai.Internal (Response (ResponseBuilder)) +import Network.Wai.Internal (Response (ResponseBuilder), requestHeaders) import Network.Wai.Test (defaultRequest, request, runSession, simpleBody, simpleHeaders, simpleStatus) -import Servant.API ((:<|>) (..), (:>), Capture, Delete, +import Servant.API ((:<|>) (..), (:>), AuthProtect, BasicAuth, Capture, Delete, Get, Header (..), Headers, HttpVersion, IsSecure (..), JSON, @@ -48,14 +47,21 @@ import Servant.API ((:<|>) (..), (:>), Capture, Delete, Raw, RemoteHost, ReqBody, StdMethod (..), Verb, addHeader) import Servant.API.Internal.Test.ComprehensiveAPI -import Servant.Server (ServantErr (..), Server, err404, - serve, Config(EmptyConfig)) +import Servant.Server (ServantErr (..), Server, err401, err404, + serve, Config((:.), EmptyConfig)) import Test.Hspec (Spec, context, describe, it, shouldBe, shouldContain) +import qualified Test.Hspec.Wai as THW import Test.Hspec.Wai (get, liftIO, matchHeaders, matchStatus, request, shouldRespondWith, with, (<:>)) +import qualified Test.Hspec.Wai as THW +import Servant.Server.Internal.Auth + (AuthHandler, AuthReturnType, BasicAuthCheck (BasicAuthCheck), + BasicAuthResult (Authorized, Unauthorized), mkAuthHandler) + +import Servant.Server.Internal.Auth import Servant.Server.Internal.RoutingApplication (toApplication, RouteResult(..)) import Servant.Server.Internal.Router @@ -86,6 +92,7 @@ spec = do responseHeadersSpec routerSpec miscCombinatorSpec + authSpec ------------------------------------------------------------------------------ -- * verbSpec {{{ @@ -528,6 +535,53 @@ miscCombinatorSpec = with (return $ serve miscApi EmptyConfig miscServ) $ go "/host" "\"0.0.0.0:0\"" where go path res = Test.Hspec.Wai.get path `shouldRespondWith` res + +-- }}} +------------------------------------------------------------------------------ +-- * Authentication {{{ +------------------------------------------------------------------------------ +type AuthAPI = BasicAuth "foo" :> "basic" :> Get '[JSON] Animal + :<|> AuthProtect "auth" :> "auth" :> Get '[JSON] Animal +authApi :: Proxy AuthAPI +authApi = Proxy +authServer :: Server AuthAPI +authServer = const (return jerry) :<|> const (return tweety) + +type instance AuthReturnType (BasicAuth "foo") = () +type instance AuthReturnType (AuthProtect "auth") = () + +authConfig :: Config '[ BasicAuthCheck () + , AuthHandler Request () + ] +authConfig = + let basicHandler = BasicAuthCheck $ (\usr pass -> + if usr == "servant" && pass == "server" + then return (Authorized ()) + else return Unauthorized + ) + authHandler = (\req -> + if elem ("Auth", "secret") (requestHeaders req) + then return () + else throwE err401 + ) + in basicHandler :. mkAuthHandler authHandler :. EmptyConfig + +authSpec :: Spec +authSpec = do + describe "Servant.API.Auth" $ do + with (return (serve authApi authConfig authServer)) $ do + + context "Basic Authentication" $ do + it "returns with 401 with bad password" $ do + get "/basic" `shouldRespondWith` 401 + it "returns 200 with the right password" $ do + THW.request methodGet "/basic" [("Authorization","Basic c2VydmFudDpzZXJ2ZXI=")] "" `shouldRespondWith` 200 + + context "Custom Auth Protection" $ do + it "returns 401 when missing headers" $ do + get "/auth" `shouldRespondWith` 401 + it "returns 200 with the right header" $ do + THW.request methodGet "/auth" [("Auth","secret")] "" `shouldRespondWith` 200 -- }}} ------------------------------------------------------------------------------ -- * Test data types {{{ diff --git a/servant/servant.cabal b/servant/servant.cabal index 437c9843..c8a21c1f 100644 --- a/servant/servant.cabal +++ b/servant/servant.cabal @@ -27,6 +27,7 @@ library exposed-modules: Servant.API Servant.API.Alternative + Servant.API.Auth Servant.API.Capture Servant.API.ContentTypes Servant.API.Header diff --git a/servant/src/Servant/API.hs b/servant/src/Servant/API.hs index 2da0d4cf..ca2acd89 100644 --- a/servant/src/Servant/API.hs +++ b/servant/src/Servant/API.hs @@ -37,6 +37,9 @@ module Servant.API ( -- * Response Headers module Servant.API.ResponseHeaders, + -- * Authentication + module Servant.API.Auth, + -- * Untyped endpoints module Servant.API.Raw, -- | Plugging in a wai 'Network.Wai.Application', serving directories @@ -51,6 +54,7 @@ module Servant.API ( ) where import Servant.API.Alternative ((:<|>) (..)) +import Servant.API.Auth (BasicAuth, AuthProtect) import Servant.API.Capture (Capture) import Servant.API.ContentTypes (Accept (..), FormUrlEncoded, FromFormUrlEncoded (..), JSON, diff --git a/servant/src/Servant/API/Auth.hs b/servant/src/Servant/API/Auth.hs new file mode 100644 index 00000000..00a11adf --- /dev/null +++ b/servant/src/Servant/API/Auth.hs @@ -0,0 +1,25 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE PolyKinds #-} +module Servant.API.Auth where + +import Data.Typeable (Typeable) +import GHC.TypeLits (Symbol) + + +-- | Combinator for . +-- +-- *IMPORTANT*: Only use Basic Auth over HTTPS! Credentials are not hashed or +-- encrypted. Note also that because the same credentials are sent on every +-- request, Basic Auth is not as secure as some alternatives. +-- +-- In Basic Auth, username and password are base64-encoded and transmitted via +-- the @Authorization@ header. Handshakes are not required, making it +-- relatively efficient. +data BasicAuth (realm :: Symbol) + deriving (Typeable) + +-- | A generalized Authentication combinator. +data AuthProtect (tag :: k) + deriving (Typeable)