diff --git a/examples/hellos/README.md b/examples/hellos/README.md new file mode 100644 index 0000000..9e63ee6 --- /dev/null +++ b/examples/hellos/README.md @@ -0,0 +1,4 @@ +# Hellos example + +This directory contains code for a simple "hello streaming" service which +demonstrates use of various streaming APIs. diff --git a/examples/hellos/hellos-client/Main.hs b/examples/hellos/hellos-client/Main.hs new file mode 100644 index 0000000..bc32909 --- /dev/null +++ b/examples/hellos/hellos-client/Main.hs @@ -0,0 +1,81 @@ +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fno-warn-missing-signatures #-} +{-# OPTIONS_GHC -fno-warn-unused-binds #-} + +import Control.Monad +import qualified Data.ByteString.Lazy as BL +import Data.Protobuf.Wire.Class +import qualified Data.Text as T +import Data.Word +import GHC.Generics (Generic) +import Network.GRPC.LowLevel + +helloSS = MethodName "/hellos.Hellos/HelloSS" +helloCS = MethodName "/hellos.Hellos/HelloCS" + +data SSRqt = SSRqt { ssName :: T.Text, ssNumReplies :: Word32 } deriving (Show, Eq, Ord, Generic) +instance Message SSRqt +data SSRpy = SSRpy { ssGreeting :: T.Text } deriving (Show, Eq, Ord, Generic) +instance Message SSRpy +data CSRqt = CSRqt { csMessage :: T.Text } deriving (Show, Eq, Ord, Generic) +instance Message CSRqt +data CSRpy = CSRpy { csNumRequests :: Word32 } deriving (Show, Eq, Ord, Generic) +instance Message CSRpy + +expect :: (Eq a, Monad m, Show a) => String -> a -> a -> m () +expect ctx ex got + | ex /= got = fail $ ctx ++ " error: expected " ++ show ex ++ ", got " ++ show got + | otherwise = return () + +doHelloSS c = do + rm <- clientRegisterMethodServerStreaming c helloSS + let nr = 10 + pay = SSRqt "server streaming mode" nr + enc = BL.toStrict . toLazyByteString $ pay + eea <- clientReader c rm 5 enc mempty $ \_md recv -> do + n :: Int <- go recv 0 + expect "doHelloSS/cnt" (fromIntegral nr) n + case eea of + Left e -> fail $ "clientReader error: " ++ show e + Right (_, st, _) + | st /= StatusOk -> fail "clientReader: non-OK status" + | otherwise -> return () + where + expay = "Hello there, server streaming mode!" + go recv n = recv >>= \case + Left e -> fail $ "doHelloSS error: " ++ show e + Right Nothing -> return n + Right (Just r) -> case fromByteString r of + Left e -> fail $ "Decoding error: " ++ show e + Right r' -> do + expect "doHelloSS/rpy" expay (ssGreeting r') + go recv (n+1) + +doHelloCS c = do + rm <- clientRegisterMethodClientStreaming c helloCS + let nr = 10 + pay = CSRqt "client streaming payload" + enc = BL.toStrict . toLazyByteString $ pay + eea <- clientWriter c rm 10 mempty $ \send -> + replicateM_ (fromIntegral nr) $ send enc >>= \case + Left e -> fail $ "doHelloCS: send error: " ++ show e + Right{} -> return () + case eea of + Left e -> fail $ "clientWriter error: " ++ show e + Right (Nothing, _, _, _, _) -> fail "clientWriter error: no reply payload" + Right (Just bs, _init, _trail, st, _dtls) + | st /= StatusOk -> fail "clientWriter: non-OK status" + | otherwise -> case fromByteString bs of + Left e -> fail $ "Decoding error: " ++ show e + Right dec -> expect "doHelloCS/cnt" nr (csNumRequests dec) + +highlevelMain = withGRPC $ \g -> + withClient g (ClientConfig "localhost" 50051 []) $ \c -> do + doHelloSS c + doHelloCS c + +main = highlevelMain diff --git a/examples/hellos/hellos-cpp/Makefile b/examples/hellos/hellos-cpp/Makefile new file mode 100644 index 0000000..5153440 --- /dev/null +++ b/examples/hellos/hellos-cpp/Makefile @@ -0,0 +1,115 @@ +# +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +CXX = g++ +CPPFLAGS += -I/usr/local/include -pthread +CXXFLAGS += -std=c++11 +LDFLAGS += -L/usr/local/lib `pkg-config --libs grpc++` -lprotobuf -lpthread +PROTOC = protoc +GRPC_CPP_PLUGIN = grpc_cpp_plugin +GRPC_CPP_PLUGIN_PATH ?= `which $(GRPC_CPP_PLUGIN)` + +PROTOS_PATH = .. + +vpath %.proto $(PROTOS_PATH) + +all: system-check hellos_client hellos_server + +hellos_client: hellos.pb.o hellos.grpc.pb.o hellos_client.o + $(CXX) $^ $(LDFLAGS) -o $@ + +hellos_server: hellos.pb.o hellos.grpc.pb.o hellos_server.o + $(CXX) $^ $(LDFLAGS) -o $@ + +.PRECIOUS: %.grpc.pb.cc +%.grpc.pb.cc: %.proto + $(PROTOC) -I $(PROTOS_PATH) --grpc_out=. --plugin=protoc-gen-grpc=$(GRPC_CPP_PLUGIN_PATH) $< + +.PRECIOUS: %.pb.cc +%.pb.cc: %.proto + $(PROTOC) -I $(PROTOS_PATH) --cpp_out=. $< + +clean: + rm -f *.o *.pb.cc *.pb.h hellos_client hellos_server + + +# The following is to test your system and ensure a smoother experience. +# They are by no means necessary to actually compile a grpc-enabled software. + +PROTOC_CMD = which $(PROTOC) +PROTOC_CHECK_CMD = $(PROTOC) --version | grep -q libprotoc.3 +PLUGIN_CHECK_CMD = which $(GRPC_CPP_PLUGIN) +HAS_PROTOC = $(shell $(PROTOC_CMD) > /dev/null && echo true || echo false) +ifeq ($(HAS_PROTOC),true) +HAS_VALID_PROTOC = $(shell $(PROTOC_CHECK_CMD) 2> /dev/null && echo true || echo false) +endif +HAS_PLUGIN = $(shell $(PLUGIN_CHECK_CMD) > /dev/null && echo true || echo false) + +SYSTEM_OK = false +ifeq ($(HAS_VALID_PROTOC),true) +ifeq ($(HAS_PLUGIN),true) +SYSTEM_OK = true +endif +endif + +system-check: +ifneq ($(HAS_VALID_PROTOC),true) + @echo " DEPENDENCY ERROR" + @echo + @echo "You don't have protoc 3.0.0 installed in your path." + @echo "Please install Google protocol buffers 3.0.0 and its compiler." + @echo "You can find it here:" + @echo + @echo " https://github.com/google/protobuf/releases/tag/v3.0.0-beta-4" + @echo + @echo "Here is what I get when trying to evaluate your version of protoc:" + @echo + -$(PROTOC) --version + @echo + @echo +endif +ifneq ($(HAS_PLUGIN),true) + @echo " DEPENDENCY ERROR" + @echo + @echo "You don't have the grpc c++ protobuf plugin installed in your path." + @echo "Please install grpc. You can find it here:" + @echo + @echo " https://github.com/grpc/grpc" + @echo + @echo "Here is what I get when trying to detect if you have the plugin:" + @echo + -which $(GRPC_CPP_PLUGIN) + @echo + @echo +endif +ifneq ($(SYSTEM_OK),true) + @false +endif diff --git a/examples/hellos/hellos-cpp/README.md b/examples/hellos/hellos-cpp/README.md new file mode 100644 index 0000000..01b4b08 --- /dev/null +++ b/examples/hellos/hellos-cpp/README.md @@ -0,0 +1,23 @@ +# gRPC C++ hellos example + +This directory contains C++ client/server code for the "hellos" streaming mode +example. + +## Building + +Just run make. + +## Usage + +For running the C++ client and server against each other, + +$ make +$ ./hellos_server & +$ ./hellos_client + +For running the C++ client against the Haskell server + +$ stack build +$ make +$ stack exec hellos-server & +$ ./hellos_client diff --git a/examples/hellos/hellos-cpp/hellos_client.cc b/examples/hellos/hellos-cpp/hellos_client.cc new file mode 100644 index 0000000..661e53d --- /dev/null +++ b/examples/hellos/hellos-cpp/hellos_client.cc @@ -0,0 +1,144 @@ +#include +#include +#include +#include + +#include + +#include "hellos.grpc.pb.h" + +using grpc::Channel; +using grpc::ClientContext; +using grpc::ClientReader; +using grpc::ClientWriter; +using grpc::ClientReaderWriter; +using grpc::Status; +using hellos::SSRqt; +using hellos::SSRpy; +using hellos::CSRqt; +using hellos::CSRpy; +using hellos::BiRqtRpy; +using hellos::Hellos; + +static void Die(const std::string& msg) { + std::cerr << "Fatal error: " << msg << std::endl; + exit(1); +} + +static void CheckRPCStatus(const std::string& ctx, Status status) { + if (status.ok()) { + std::cout << ctx << ": RPC successful." << std::endl; + } else { + std::cout << ctx << ": Got failed status code: " << status.error_code() << std::endl; + std::cout << ctx << ": Got failed status msg: " << status.error_message() << std::endl; + Die(ctx + ": RPC failed"); + } +} + +class HellosClient { + public: + HellosClient(std::shared_ptr channel) + : stub_(Hellos::NewStub(channel)) {} + + void DoHelloSS(const std::string& name, unsigned n) { + SSRqt rqt; + rqt.set_name(name); + rqt.set_num_replies(n); + + SSRpy rpy; + ClientContext ctx; + + std::unique_ptr > reader(stub_->HelloSS(&ctx, rqt)); + unsigned rpyCnt = 0; + while (reader->Read(&rpy)) { + ++rpyCnt; + std::string ex("Hello there, " + name + "!"); + if (rpy.greeting() != ex) { + Die("DoHelloSS/rpy: expected payload '" + ex + + "', got '" + rpy.greeting() + "'"); + } + } + Status status = reader->Finish(); + CheckRPCStatus("DoHelloSS", status); + if (rpyCnt != n) + Die("DoHelloSS/cnt: expected " + std::to_string(n) + + "replies, got " + std::to_string(rpyCnt)); + } + + void DoHelloCS(const std::string& pay, unsigned n) { + CSRqt rqt; + rqt.set_message(pay); + + CSRpy rpy; + ClientContext ctx; + + std::unique_ptr > writer(stub_->HelloCS(&ctx, &rpy)); + for (unsigned i = 0; i < n; ++i) { + if (!writer->Write(rqt)) { + // Broken stream + break; + } + } + writer->WritesDone(); + Status status = writer->Finish(); + CheckRPCStatus("DoHelloCS", status); + if (rpy.num_requests() != n) + Die("DoHelloCS/cnt: expected request count " + std::to_string(n) + + ", got " + std::to_string(rpy.num_requests())); + } + + void DoHelloBi(const std::string& pay, unsigned n) { + BiRqtRpy rqt; + rqt.set_message(pay); + ClientContext ctx; + + std::shared_ptr > strm(stub_->HelloBi(&ctx)); + + // Spawn a writer thread which sends rqt to the server n times. + std::thread writer([strm,rqt,n]() { + for(unsigned i = 0; i < n; ++i) { + strm->Write(rqt); + } + strm->WritesDone(); + }); + + // Concurrently, read back echo'd replies from the server until there are no + // more to consume; ensure we get the expected number of responses after + // there's nothing left to read. + BiRqtRpy rpy; + unsigned rpyCnt = 0; + while(strm->Read(&rpy)) { + if (rpy.message() != pay) + Die("DoHelloBi/rpy: expected payload '" + pay + + "', got '" + rpy.message() + "'"); + ++rpyCnt; + } + writer.join(); + + Status status = strm->Finish(); + CheckRPCStatus("DoHelloBi", status); + if (rpyCnt != n) + Die("DoHelloBi/cnt: expected reply count " + std::to_string(n) + + ", got " + std::to_string(rpyCnt)); + } + + private: + std::unique_ptr stub_; +}; + +int main(int argc, char** argv) { + // Instantiate the client. It requires a channel, out of which the actual RPCs + // are created. This channel models a connection to an endpoint (in this case, + // localhost at port 50051). We indicate that the channel isn't authenticated + // (use of InsecureChannelCredentials()). + HellosClient hellos(grpc::CreateChannel( + "localhost:50051", grpc::InsecureChannelCredentials())); + const unsigned n = 100000; + std::cout << "-------------- HelloSS --------------" << std::endl; + hellos.DoHelloSS("server streaming mode", n); + std::cout << "-------------- HelloCS --------------" << std::endl; + hellos.DoHelloCS("client streaming payload", n); + std::cout << "-------------- HelloBi --------------" << std::endl; + hellos.DoHelloBi("bidi payload", n); + return 0; +} diff --git a/examples/hellos/hellos-cpp/hellos_server.cc b/examples/hellos/hellos-cpp/hellos_server.cc new file mode 100644 index 0000000..988b271 --- /dev/null +++ b/examples/hellos/hellos-cpp/hellos_server.cc @@ -0,0 +1,87 @@ +#include +#include +#include + +#include + +#include "hellos.grpc.pb.h" + +using grpc::Server; +using grpc::ServerBuilder; +using grpc::ServerContext; +using grpc::ServerWriter; +using grpc::ServerReader; +using grpc::ServerReaderWriter; +using grpc::Status; +using hellos::SSRqt; +using hellos::SSRpy; +using hellos::CSRqt; +using hellos::CSRpy; +using hellos::BiRqtRpy; +using hellos::Hellos; + +static void Die(const std::string& msg) { + std::cerr << "Fatal error: " << msg << std::endl; + exit(1); +} + +class HellosImpl final : public Hellos::Service { + Status HelloSS(ServerContext* context, + const SSRqt* rqt, + ServerWriter* writer) override { + for (unsigned i = 0; i < rqt->num_replies(); ++i) { + SSRpy rpy; + rpy.set_greeting("Hello there, " + rqt->name() + "!"); + writer->Write(rpy); + } + return Status::OK; + } + + Status HelloCS(ServerContext* context, + ServerReader* reader, + CSRpy* rpy) override { + CSRqt rqt; + unsigned rqtCnt = 0; + std::string ex("client streaming payload"); + while (reader->Read(&rqt)) { + if (rqt.message() != ex) + Die("HelloCS/rpy: expected payload '" + ex + + "', got '" + rqt.message() + "'"); + ++rqtCnt; + } + rpy->set_num_requests(rqtCnt); + return Status::OK; + } + + Status HelloBi(ServerContext* context, + ServerReaderWriter* strm) override { + BiRqtRpy rqt; + while (strm->Read(&rqt)) { + strm->Write(rqt); + } + return Status::OK; + } + +}; + +void RunServer() { + std::string server_address("0.0.0.0:50051"); + HellosImpl service; + ServerBuilder builder; + // Listen on the given address without any authentication mechanism. + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + // Register "service" as the instance through which we'll communicate with + // clients. In this case it corresponds to an *synchronous* service. + builder.RegisterService(&service); + // Finally assemble the server. + std::unique_ptr server(builder.BuildAndStart()); + std::cout << "Server listening on " << server_address << std::endl; + // Wait for the server to shutdown. Note that some other thread must be + // responsible for shutting down the server for this call to ever return. + server->Wait(); +} + +int main(int argc, char** argv) { + RunServer(); + return 0; +} diff --git a/examples/hellos/hellos-server/Main.hs b/examples/hellos/hellos-server/Main.hs new file mode 100644 index 0000000..d14256a --- /dev/null +++ b/examples/hellos/hellos-server/Main.hs @@ -0,0 +1,83 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# OPTIONS_GHC -fno-warn-missing-signatures #-} +{-# OPTIONS_GHC -fno-warn-unused-binds #-} + +import Control.Monad +import Data.Function (fix) +import Data.Monoid +import Data.Protobuf.Wire.Class +import qualified Data.Text as T +import Data.Word +import GHC.Generics (Generic) +import Network.GRPC.HighLevel.Server +import qualified Network.GRPC.HighLevel.Server.Unregistered as U +import Network.GRPC.LowLevel + +serverMeta :: MetadataMap +serverMeta = [("test_meta", "test_meta_value")] + +data SSRqt = SSRqt { ssName :: T.Text, ssNumReplies :: Word32 } deriving (Show, Eq, Ord, Generic) +instance Message SSRqt +data SSRpy = SSRpy { ssGreeting :: T.Text } deriving (Show, Eq, Ord, Generic) +instance Message SSRpy +data CSRqt = CSRqt { csMessage :: T.Text } deriving (Show, Eq, Ord, Generic) +instance Message CSRqt +data CSRpy = CSRpy { csNumRequests :: Word32 } deriving (Show, Eq, Ord, Generic) +instance Message CSRpy +data BiRqtRpy = BiRqtRpy { biMessage :: T.Text } deriving (Show, Eq, Ord, Generic) +instance Message BiRqtRpy + +expect :: (Eq a, Monad m, Show a) => String -> a -> a -> m () +expect ctx ex got + | ex /= got = fail $ ctx ++ " error: expected " ++ show ex ++ ", got " ++ show got + | otherwise = return () + +helloSS :: Handler 'ServerStreaming +helloSS = ServerStreamHandler "/hellos.Hellos/HelloSS" $ \sc send -> do + let SSRqt{..} = payload sc + replicateM_ (fromIntegral ssNumReplies) $ do + eea <- send $ SSRpy $ "Hello there, " <> ssName <> "!" + case eea of + Left e -> fail $ "helloSS error: " ++ show e + Right{} -> return () + return (serverMeta, StatusOk, StatusDetails "helloSS response details") + +helloCS :: Handler 'ClientStreaming +helloCS = ClientStreamHandler "/hellos.Hellos/HelloCS" $ \_ recv -> flip fix 0 $ \go n -> + recv >>= \case + Left e -> fail $ "helloCS error: " ++ show e + Right Nothing -> return (Just (CSRpy n), mempty, StatusOk, StatusDetails "helloCS details") + Right (Just rqt) -> do + expect "helloCS" "client streaming payload" (csMessage rqt) + go (n+1) + +helloBi :: Handler 'BiDiStreaming +helloBi = BiDiStreamHandler "/hellos.Hellos/HelloBi" $ \_ recv send -> fix $ \go -> + recv >>= \case + Left e -> fail $ "helloBi recv error: " ++ show e + Right Nothing -> return (mempty, StatusOk, StatusDetails "helloBi details") + Right (Just rqt) -> do + expect "helloBi" "bidi payload" (biMessage rqt) + send rqt >>= \case + Left e -> fail $ "helloBi send error: " ++ show e + _ -> go + +highlevelMainUnregistered :: IO () +highlevelMainUnregistered = + U.serverLoop defaultOptions{ + optServerStreamHandlers = [helloSS] + , optClientStreamHandlers = [helloCS] + , optBiDiStreamHandlers = [helloBi] + } + +main :: IO () +main = highlevelMainUnregistered + +defConfig :: ServerConfig +defConfig = ServerConfig "localhost" 50051 [] [] [] [] [] diff --git a/examples/hellos/hellos.proto b/examples/hellos/hellos.proto new file mode 100644 index 0000000..3225b3a --- /dev/null +++ b/examples/hellos/hellos.proto @@ -0,0 +1,35 @@ +syntax = "proto3"; +package hellos; + +service Hellos { + + // Server streaming: Request n repetitions of a greeting be sent, based off of a given name + rpc HelloSS(SSRqt) returns (stream SSRpy) {} + + // Client streaming: Send n requests and receive a total request count when done + rpc HelloCS(stream CSRqt) returns (CSRpy) {} + + // Simple bidi streaming: ping-pong echo + rpc HelloBi(stream BiRqtRpy) returns (stream BiRqtRpy) {} +} + +message SSRqt { + string name = 1; + uint32 num_replies = 2; +} + +message SSRpy { + string greeting = 1; +} + +message CSRqt { + string message = 1; +} + +message CSRpy { + uint32 num_requests = 1; +} + +message BiRqtRpy { + string message = 1; +} diff --git a/grpc-haskell.cabal b/grpc-haskell.cabal index ba9a961..eee67ee 100644 --- a/grpc-haskell.cabal +++ b/grpc-haskell.cabal @@ -90,6 +90,44 @@ library CPP-Options: -DDEBUG CC-Options: -DGRPC_HASKELL_DEBUG +executable hellos-server + if flag(with-examples) + build-depends: + base ==4.8.* + , async + , bytestring == 0.10.* + , containers ==0.5.* + , grpc-haskell + , proto3-wire + , protobuf-wire + , text + , transformers + else + buildable: False + default-language: Haskell2010 + ghc-options: -Wall -g -threaded -rtsopts -with-rtsopts=-N -O2 + hs-source-dirs: examples/hellos/hellos-server + main-is: Main.hs + +executable hellos-client + if flag(with-examples) + build-depends: + base ==4.8.* + , async + , bytestring == 0.10.* + , containers ==0.5.* + , grpc-haskell + , proto3-wire + , protobuf-wire + , text + , transformers + else + buildable: False + default-language: Haskell2010 + ghc-options: -Wall -g -threaded -rtsopts -with-rtsopts=-N -O2 + hs-source-dirs: examples/hellos/hellos-client + main-is: Main.hs + executable echo-server if flag(with-examples) build-depends: diff --git a/src/Network/GRPC/HighLevel/Server.hs b/src/Network/GRPC/HighLevel/Server.hs index b6d1cb9..be22278 100644 --- a/src/Network/GRPC/HighLevel/Server.hs +++ b/src/Network/GRPC/HighLevel/Server.hs @@ -1,14 +1,15 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RecordWildCards #-} module Network.GRPC.HighLevel.Server where import Control.Concurrent.Async +import qualified Control.Exception as CE import Control.Monad import Data.ByteString (ByteString) import qualified Data.ByteString.Lazy as BL @@ -16,6 +17,7 @@ import Data.Protobuf.Wire.Class import Network.GRPC.LowLevel import qualified Network.GRPC.LowLevel.Call.Unregistered as U import qualified Network.GRPC.LowLevel.Server.Unregistered as U +import System.IO type ServerHandler a b = ServerCall a @@ -25,7 +27,7 @@ convertServerHandler :: (Message a, Message b) => ServerHandler a b -> ServerHandlerLL convertServerHandler f c = case fromByteString (payload c) of - Left x -> error $ "Failed to deserialize message: " ++ show x + Left x -> CE.throw (GRPCIODecodeError x) Right x -> do (y, tm, sc, sd) <- f (fmap (const x) c) return (toBS y, tm, sc, sd) @@ -54,7 +56,7 @@ convertServerWriterHandler f c send = f (convert <$> c) (convertSend send) where convert bs = case fromByteString bs of - Left x -> error $ "deserialization error: " ++ show x -- TODO FIXME + Left x -> CE.throw (GRPCIODecodeError x) Right x -> x type ServerRWHandler a b = @@ -121,14 +123,8 @@ handlerMethodName (ClientStreamHandler m _) = m handlerMethodName (ServerStreamHandler m _) = m handlerMethodName (BiDiStreamHandler m _) = m --- TODO: find some idiomatic way to do logging that doesn't force the user --- into anything they don't want. -logShow :: Show a => a -> IO () -logShow = print - -logAskReport :: Show a => a -> IO () -logAskReport x = - logShow $ show x ++ " This probably indicates a bug in gRPC-haskell. Please report this error." +logMsg :: String -> IO () +logMsg = hPutStrLn stderr -- | Handles errors that result from trying to handle a call on the server. -- For each error, takes a different action depending on the severity in the @@ -137,12 +133,17 @@ logAskReport x = handleCallError :: Either GRPCIOError a -> IO () handleCallError (Right _) = return () handleCallError (Left GRPCIOTimeout) = - --Probably a benign timeout (such as a client disappearing), noop for now. + -- Probably a benign timeout (such as a client disappearing), noop for now. return () handleCallError (Left GRPCIOShutdown) = - --Server shutting down. Benign. + -- Server shutting down. Benign. return () -handleCallError (Left x) = logAskReport x +handleCallError (Left (GRPCIODecodeError e)) = + logMsg $ "Decoding error: " ++ show e +handleCallError (Left (GRPCIOHandlerException e)) = + logMsg $ "Handler exception caught: " ++ show e +handleCallError (Left x) = + logMsg $ show x ++ ": This probably indicates a bug in gRPC-haskell. Please report this error." loopWError :: Int -> IO (Either GRPCIOError a) @@ -157,9 +158,7 @@ handleLoop :: Server -> (Handler a, RegisteredMethod a) -> IO () handleLoop s (UnaryHandler _ f, rm) = - loopWError 0 $ do - --grpcDebug' "handleLoop about to block on serverHandleNormalCall" - serverHandleNormalCall s rm mempty $ convertServerHandler f + loopWError 0 $ serverHandleNormalCall s rm mempty $ convertServerHandler f handleLoop s (ClientStreamHandler _ f, rm) = loopWError 0 $ serverReader s rm mempty $ convertServerReaderHandler f handleLoop s (ServerStreamHandler _ f, rm) = @@ -229,6 +228,6 @@ serverLoop opts = unknownHandler s = --TODO: is this working? U.serverHandleNormalCall s mempty $ \call _ -> do - logShow $ "Requested unknown endpoint: " ++ show (U.callMethod call) + logMsg $ "Requested unknown endpoint: " ++ show (U.callMethod call) return ("", mempty, StatusNotFound, StatusDetails "Unknown method") diff --git a/src/Network/GRPC/HighLevel/Server/Unregistered.hs b/src/Network/GRPC/HighLevel/Server/Unregistered.hs index 2b83d88..0ad8990 100644 --- a/src/Network/GRPC/HighLevel/Server/Unregistered.hs +++ b/src/Network/GRPC/HighLevel/Server/Unregistered.hs @@ -1,18 +1,22 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} module Network.GRPC.HighLevel.Server.Unregistered where +import Control.Arrow +import qualified Control.Exception as CE import Control.Monad -import Data.Protobuf.Wire.Class import Data.Foldable (find) +import Data.Protobuf.Wire.Class import Network.GRPC.HighLevel.Server import Network.GRPC.LowLevel +import qualified Network.GRPC.LowLevel.Call.Unregistered as U import qualified Network.GRPC.LowLevel.Server.Unregistered as U -import qualified Network.GRPC.LowLevel.Call.Unregistered as U dispatchLoop :: Server -> MetadataMap @@ -41,7 +45,10 @@ dispatchLoop server meta hN hC hS hB = , mempty , StatusNotFound , StatusDetails "unknown method") - handleError f = f >>= handleCallError + + handleError = (handleCallError . left herr =<<) . CE.try + where herr (e :: CE.SomeException) = GRPCIOHandlerException (show e) + unaryHandler :: (Message a, Message b) => U.ServerCall -> ServerHandler a b diff --git a/src/Network/GRPC/LowLevel/GRPC.hs b/src/Network/GRPC/LowLevel/GRPC.hs index 9d79717..c15dbfe 100644 --- a/src/Network/GRPC/LowLevel/GRPC.hs +++ b/src/Network/GRPC/LowLevel/GRPC.hs @@ -9,6 +9,7 @@ import Control.Exception import Data.String (IsString) import qualified Data.ByteString as B import qualified Data.Map as M +import Data.Typeable import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe.Op as C import Proto3.Wire.Decode (ParseError) @@ -27,7 +28,8 @@ withGRPC :: (GRPC -> IO a) -> IO a withGRPC = bracket (C.grpcInit >> return GRPC) (\_ -> grpcDebug "withGRPC: shutting down" >> C.grpcShutdown) --- | Describes all errors that can occur while running a GRPC-related IO action. +-- | Describes all errors that can occur while running a GRPC-related IO +-- action. data GRPCIOError = GRPCIOCallError C.CallError -- ^ Errors that can occur while the call is in flight. These -- errors come from the core gRPC library directly. @@ -43,10 +45,12 @@ data GRPCIOError = GRPCIOCallError C.CallError -- reasonable amount of time. | GRPCIOUnknownError | GRPCIOBadStatusCode C.StatusCode StatusDetails + | GRPCIODecodeError ParseError - | GRPCIOInternalMissingExpectedPayload | GRPCIOInternalUnexpectedRecv String -- debugging description - deriving (Show, Eq) + | GRPCIOHandlerException String + deriving (Eq, Show, Typeable) +instance Exception GRPCIOError throwIfCallError :: C.CallError -> Either GRPCIOError () throwIfCallError C.CallOk = Right () diff --git a/src/Network/GRPC/LowLevel/Server.hs b/src/Network/GRPC/LowLevel/Server.hs index 22349fd..e17ee1f 100644 --- a/src/Network/GRPC/LowLevel/Server.hs +++ b/src/Network/GRPC/LowLevel/Server.hs @@ -27,10 +27,9 @@ import Control.Concurrent.STM.TVar (TVar , readTVarIO , newTVarIO) import Control.Exception (bracket, finally) -import Control.Monad hiding (mapM_) +import Control.Monad import Control.Monad.Trans.Except import Data.ByteString (ByteString) -import Data.Foldable (mapM_) import qualified Data.Set as S import Network.GRPC.LowLevel.Call import Network.GRPC.LowLevel.CompletionQueue (CompletionQueue, @@ -149,7 +148,7 @@ startServer grpc conf@ServerConfig{..} = stopServer :: Server -> IO () -- TODO: Do method handles need to be freed? -stopServer server@Server{ unsafeServer = s, .. } = do +stopServer Server{ unsafeServer = s, .. } = do grpcDebug "stopServer: calling shutdownNotify." shutdownNotify serverCQ grpcDebug "stopServer: cancelling all calls." diff --git a/src/Network/GRPC/LowLevel/Server/Unregistered.hs b/src/Network/GRPC/LowLevel/Server/Unregistered.hs index bd29312..c65ffd6 100644 --- a/src/Network/GRPC/LowLevel/Server/Unregistered.hs +++ b/src/Network/GRPC/LowLevel/Server/Unregistered.hs @@ -7,7 +7,6 @@ import Control.Exception (finally) import Control.Monad.Trans.Except import Data.ByteString (ByteString) import Network.GRPC.LowLevel.Call.Unregistered -import Network.GRPC.LowLevel.CompletionQueue (createCompletionQueue) import Network.GRPC.LowLevel.CompletionQueue.Unregistered (serverRequestCall) import Network.GRPC.LowLevel.GRPC import Network.GRPC.LowLevel.Op (Op (..) @@ -29,7 +28,7 @@ import qualified Network.GRPC.Unsafe.Op as C serverCreateCall :: Server -> IO (Either GRPCIOError ServerCall) -serverCreateCall Server{..} = do +serverCreateCall Server{..} = serverRequestCall unsafeServer serverCQ serverCallCQ withServerCall :: Server diff --git a/tests/LowLevelTests.hs b/tests/LowLevelTests.hs index 5c3b6f0..e3d1493 100644 --- a/tests/LowLevelTests.hs +++ b/tests/LowLevelTests.hs @@ -1,10 +1,14 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedLists #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} module LowLevelTests where @@ -20,7 +24,6 @@ import Network.GRPC.LowLevel import qualified Network.GRPC.LowLevel.Call.Unregistered as U import qualified Network.GRPC.LowLevel.Client.Unregistered as U import qualified Network.GRPC.LowLevel.Server.Unregistered as U -import Pipes ((>->)) import qualified Pipes as P import Test.Tasty import Test.Tasty.HUnit as HU (Assertion, @@ -105,7 +108,7 @@ testMixRegisteredUnregistered = return () where regThread = do let rm = head (normalMethods s) - r <- serverHandleNormalCall s rm dummyMeta $ \c -> do + _r <- serverHandleNormalCall s rm dummyMeta $ \c -> do payload c @?= "Hello" return ("reply test", dummyMeta, StatusOk, "") return () @@ -284,11 +287,10 @@ testBiDiStreaming = trailMD = dummyMeta serverStatus = StatusOk serverDtls = "deets" - is act x = act >>= liftIO . (@?= x) client c = do rm <- clientRegisterMethodBiDiStreaming c "/bidi" - eea <- clientRW c rm 10 clientInitMD $ \initMD recv send -> do + eea <- clientRW c rm 10 clientInitMD $ \_initMD recv send -> do send "cw0" `is` Right () recv `is` Right (Just "sw0") send "cw1" `is` Right () @@ -320,11 +322,10 @@ testBiDiStreamingUnregistered = trailMD = dummyMeta serverStatus = StatusOk serverDtls = "deets" - is act x = act >>= liftIO . (@?= x) client c = do rm <- clientRegisterMethodBiDiStreaming c "/bidi" - eea <- clientRW c rm 10 clientInitMD $ \initMD recv send -> do + eea <- clientRW c rm 10 clientInitMD $ \_initMD recv send -> do send "cw0" `is` Right () recv `is` Right (Just "sw0") send "cw1" `is` Right () @@ -389,7 +390,7 @@ testGoaway = rm <- clientRegisterMethodNormal c "/foo" clientRequest c rm 10 "" mempty clientRequest c rm 10 "" mempty - lastResult <- clientRequest c rm 1 "" mempty + lastResult <- clientRequest c rm 1 "" mempty assertBool "Client handles server shutdown gracefully" $ lastResult == badStatus StatusUnavailable || @@ -423,8 +424,7 @@ testServerCallExpirationCheck = where client c = do rm <- clientRegisterMethodNormal c "/foo" - result <- clientRequest c rm 3 "" mempty - return () + void $ clientRequest c rm 3 "" mempty server s = do let rm = head (normalMethods s) serverHandleNormalCall s rm mempty $ \c -> do @@ -447,12 +447,11 @@ testCustomUserAgent = client = TestClient (ClientConfig "localhost" 50051 clientArgs) $ \c -> do rm <- clientRegisterMethodNormal c "/foo" - result <- clientRequest c rm 4 "" mempty - return () + void $ clientRequest c rm 4 "" mempty server = TestServer (serverConf (["/foo"],[],[],[])) $ \s -> do let rm = head (normalMethods s) serverHandleNormalCall s rm mempty $ \c -> do - let ua = (metadata c) M.! "user-agent" + let ua = metadata c M.! "user-agent" assertBool "User agent prefix is present" $ isPrefixOf "prefix!" ua assertBool "User agent suffix is present" $ isSuffixOf "suffix!" ua return dummyResp @@ -468,8 +467,7 @@ testClientCompression = 50051 [CompressionAlgArg GrpcCompressDeflate]) $ \c -> do rm <- clientRegisterMethodNormal c "/foo" - result <- clientRequest c rm 1 "hello" mempty - return () + void $ clientRequest c rm 1 "hello" mempty server = TestServer (serverConf (["/foo"],[],[],[])) $ \s -> do let rm = head (normalMethods s) serverHandleNormalCall s rm mempty $ \c -> do diff --git a/tests/UnsafeTests.hs b/tests/UnsafeTests.hs index 29593dc..54686f3 100644 --- a/tests/UnsafeTests.hs +++ b/tests/UnsafeTests.hs @@ -3,25 +3,20 @@ module UnsafeTests (unsafeTests) where import Control.Concurrent (threadDelay) -import Control.Concurrent.Async import Control.Exception (bracket_) import Control.Monad import qualified Data.ByteString as B import Foreign.Marshal.Alloc -import Foreign.Ptr import Foreign.Storable import Network.GRPC.Unsafe import Network.GRPC.Unsafe.ByteBuffer -import Network.GRPC.Unsafe.Constants import Network.GRPC.Unsafe.Metadata -import Network.GRPC.Unsafe.Op import Network.GRPC.Unsafe.Slice import Network.GRPC.Unsafe.Time import Network.GRPC.Unsafe.ChannelArgs import System.Clock import Test.Tasty -import Test.Tasty.HUnit as HU (testCase, (@?=), - assertBool) +import Test.Tasty.HUnit as HU (testCase, (@?=)) unsafeTests :: TestTree unsafeTests = testGroup "Unit tests for unsafe C bindings" @@ -133,3 +128,6 @@ grpc = bracket_ grpcInit grpcShutdown . void threadDelaySecs :: Int -> IO () threadDelaySecs = threadDelay . (* 10^(6::Int)) + +_nowarnUnused :: a +_nowarnUnused = assertCqEventComplete `undefined` threadDelaySecs