diff --git a/examples/echo/echo-client/Main.hs b/examples/echo/echo-client/Main.hs index e2e0c2a..fc71a54 100644 --- a/examples/echo/echo-client/Main.hs +++ b/examples/echo/echo-client/Main.hs @@ -1,21 +1,53 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} {-# OPTIONS_GHC -fno-warn-missing-signatures #-} {-# OPTIONS_GHC -fno-warn-unused-binds #-} import Control.Monad +import qualified Data.ByteString.Lazy as BL +import Data.Protobuf.Wire.Class +import qualified Data.Text as T +import Data.Word +import GHC.Generics (Generic) import Network.GRPC.LowLevel -import Network.GRPC.LowLevel.Call import qualified Network.GRPC.LowLevel.Client.Unregistered as U -import System.Environment echoMethod = MethodName "/echo.Echo/DoEcho" _unregistered c = U.clientRequest c echoMethod 1 "hi" mempty -main = withGRPC $ \g -> +regMain = withGRPC $ \g -> withClient g (ClientConfig "localhost" 50051 []) $ \c -> do - rm <- clientRegisterMethod c echoMethod Normal + rm <- clientRegisterMethodNormal c echoMethod replicateM_ 100000 $ clientRequest c rm 5 "hi" mempty >>= \case Left e -> error $ "Got client error: " ++ show e - _ -> return () + Right r + | rspBody r == "hi" -> return () + | otherwise -> error $ "Got unexpected payload: " ++ show r + +-- NB: If you change these, make sure to change them in the server as well. +-- TODO: Put these in a common location (or just hack around it until CG is working) +data EchoRequest = EchoRequest {message :: T.Text} deriving (Show, Eq, Ord, Generic) +instance Message EchoRequest +data AddRequest = AddRequest {addX :: Word32, addY :: Word32} deriving (Show, Eq, Ord, Generic) +instance Message AddRequest +data AddResponse = AddResponse {answer :: Word32} deriving (Show, Eq, Ord, Generic) +instance Message AddResponse + +-- TODO: Create Network.GRPC.HighLevel.Client w/ request variants + +highlevelMain = withGRPC $ \g -> + withClient g (ClientConfig "localhost" 50051 []) $ \c -> do + rm <- clientRegisterMethodNormal c echoMethod + let pay = EchoRequest "hi" + enc = BL.toStrict . toLazyByteString $ pay + replicateM_ 1 $ clientRequest c rm 5 enc mempty >>= \case + Left e -> error $ "Got client error: " ++ show e + Right r -> case fromByteString (rspBody r) of + Left e -> error $ "Got decoding error: " ++ show e + Right dec + | dec == pay -> return () + | otherwise -> error $ "Got unexpected payload: " ++ show dec + +main = highlevelMain diff --git a/examples/echo/echo-cpp/echo-client.cc b/examples/echo/echo-cpp/echo-client.cc index 13ad30d..cda410e 100644 --- a/examples/echo/echo-cpp/echo-client.cc +++ b/examples/echo/echo-cpp/echo-client.cc @@ -6,12 +6,10 @@ #include "echo.grpc.pb.h" using namespace std; - +using namespace echo; using grpc::Channel; using grpc::ClientContext; using grpc::Status; -using echo::EchoRequest; -using echo::Echo; class EchoClient { public: @@ -32,7 +30,29 @@ private: unique_ptr stub_; }; +class AddClient { +public: + AddClient(shared_ptr chan) : stub_(Add::NewStub(chan)) {} + + AddResponse DoAdd(const uint32_t x, const uint32_t y){ + AddRequest msg; + msg.set_addx(x); + msg.set_addy(y); + + AddResponse resp; + + ClientContext ctx; + + stub_->DoAdd(&ctx, msg, &resp); + + return resp; + } +private: + unique_ptr stub_; +}; + int main(){ + /* EchoClient client(grpc::CreateChannel("localhost:50051", grpc::InsecureChannelCredentials())); string msg("hi"); @@ -43,6 +63,11 @@ int main(){ return 1; } } +*/ + AddClient client (grpc::CreateChannel("localhost:50051", + grpc::InsecureChannelCredentials())); + AddResponse answer = client.DoAdd(1,2); + cout<<"Got answer: "<& channel) } Echo::Service::Service() { + (void)Echo_method_names; AddMethod(new ::grpc::RpcServiceMethod( Echo_method_names[0], ::grpc::RpcMethod::NORMAL_RPC, @@ -55,5 +56,46 @@ Echo::Service::~Service() { } +static const char* Add_method_names[] = { + "/echo.Add/DoAdd", +}; + +std::unique_ptr< Add::Stub> Add::NewStub(const std::shared_ptr< ::grpc::ChannelInterface>& channel, const ::grpc::StubOptions& options) { + std::unique_ptr< Add::Stub> stub(new Add::Stub(channel)); + return stub; +} + +Add::Stub::Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel) + : channel_(channel), rpcmethod_DoAdd_(Add_method_names[0], ::grpc::RpcMethod::NORMAL_RPC, channel) + {} + +::grpc::Status Add::Stub::DoAdd(::grpc::ClientContext* context, const ::echo::AddRequest& request, ::echo::AddResponse* response) { + return ::grpc::BlockingUnaryCall(channel_.get(), rpcmethod_DoAdd_, context, request, response); +} + +::grpc::ClientAsyncResponseReader< ::echo::AddResponse>* Add::Stub::AsyncDoAddRaw(::grpc::ClientContext* context, const ::echo::AddRequest& request, ::grpc::CompletionQueue* cq) { + return new ::grpc::ClientAsyncResponseReader< ::echo::AddResponse>(channel_.get(), cq, rpcmethod_DoAdd_, context, request); +} + +Add::Service::Service() { + (void)Add_method_names; + AddMethod(new ::grpc::RpcServiceMethod( + Add_method_names[0], + ::grpc::RpcMethod::NORMAL_RPC, + new ::grpc::RpcMethodHandler< Add::Service, ::echo::AddRequest, ::echo::AddResponse>( + std::mem_fn(&Add::Service::DoAdd), this))); +} + +Add::Service::~Service() { +} + +::grpc::Status Add::Service::DoAdd(::grpc::ServerContext* context, const ::echo::AddRequest* request, ::echo::AddResponse* response) { + (void) context; + (void) request; + (void) response; + return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); +} + + } // namespace echo diff --git a/examples/echo/echo-cpp/echo.grpc.pb.h b/examples/echo/echo-cpp/echo.grpc.pb.h index df1c128..c22a4a8 100644 --- a/examples/echo/echo-cpp/echo.grpc.pb.h +++ b/examples/echo/echo-cpp/echo.grpc.pb.h @@ -17,6 +17,7 @@ namespace grpc { class CompletionQueue; +class Channel; class RpcService; class ServerCompletionQueue; class ServerContext; @@ -60,7 +61,7 @@ class Echo GRPC_FINAL { template class WithAsyncMethod_DoEcho : public BaseClass { private: - void BaseClassMustBeDerivedFromService(Service *service) {} + void BaseClassMustBeDerivedFromService(const Service *service) {} public: WithAsyncMethod_DoEcho() { ::grpc::Service::MarkMethodAsync(0); @@ -81,7 +82,7 @@ class Echo GRPC_FINAL { template class WithGenericMethod_DoEcho : public BaseClass { private: - void BaseClassMustBeDerivedFromService(Service *service) {} + void BaseClassMustBeDerivedFromService(const Service *service) {} public: WithGenericMethod_DoEcho() { ::grpc::Service::MarkMethodGeneric(0); @@ -97,6 +98,79 @@ class Echo GRPC_FINAL { }; }; +class Add GRPC_FINAL { + public: + class StubInterface { + public: + virtual ~StubInterface() {} + virtual ::grpc::Status DoAdd(::grpc::ClientContext* context, const ::echo::AddRequest& request, ::echo::AddResponse* response) = 0; + std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< ::echo::AddResponse>> AsyncDoAdd(::grpc::ClientContext* context, const ::echo::AddRequest& request, ::grpc::CompletionQueue* cq) { + return std::unique_ptr< ::grpc::ClientAsyncResponseReaderInterface< ::echo::AddResponse>>(AsyncDoAddRaw(context, request, cq)); + } + private: + virtual ::grpc::ClientAsyncResponseReaderInterface< ::echo::AddResponse>* AsyncDoAddRaw(::grpc::ClientContext* context, const ::echo::AddRequest& request, ::grpc::CompletionQueue* cq) = 0; + }; + class Stub GRPC_FINAL : public StubInterface { + public: + Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel); + ::grpc::Status DoAdd(::grpc::ClientContext* context, const ::echo::AddRequest& request, ::echo::AddResponse* response) GRPC_OVERRIDE; + std::unique_ptr< ::grpc::ClientAsyncResponseReader< ::echo::AddResponse>> AsyncDoAdd(::grpc::ClientContext* context, const ::echo::AddRequest& request, ::grpc::CompletionQueue* cq) { + return std::unique_ptr< ::grpc::ClientAsyncResponseReader< ::echo::AddResponse>>(AsyncDoAddRaw(context, request, cq)); + } + + private: + std::shared_ptr< ::grpc::ChannelInterface> channel_; + ::grpc::ClientAsyncResponseReader< ::echo::AddResponse>* AsyncDoAddRaw(::grpc::ClientContext* context, const ::echo::AddRequest& request, ::grpc::CompletionQueue* cq) GRPC_OVERRIDE; + const ::grpc::RpcMethod rpcmethod_DoAdd_; + }; + static std::unique_ptr NewStub(const std::shared_ptr< ::grpc::ChannelInterface>& channel, const ::grpc::StubOptions& options = ::grpc::StubOptions()); + + class Service : public ::grpc::Service { + public: + Service(); + virtual ~Service(); + virtual ::grpc::Status DoAdd(::grpc::ServerContext* context, const ::echo::AddRequest* request, ::echo::AddResponse* response); + }; + template + class WithAsyncMethod_DoAdd : public BaseClass { + private: + void BaseClassMustBeDerivedFromService(const Service *service) {} + public: + WithAsyncMethod_DoAdd() { + ::grpc::Service::MarkMethodAsync(0); + } + ~WithAsyncMethod_DoAdd() GRPC_OVERRIDE { + BaseClassMustBeDerivedFromService(this); + } + // disable synchronous version of this method + ::grpc::Status DoAdd(::grpc::ServerContext* context, const ::echo::AddRequest* request, ::echo::AddResponse* response) GRPC_FINAL GRPC_OVERRIDE { + abort(); + return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); + } + void RequestDoAdd(::grpc::ServerContext* context, ::echo::AddRequest* request, ::grpc::ServerAsyncResponseWriter< ::echo::AddResponse>* response, ::grpc::CompletionQueue* new_call_cq, ::grpc::ServerCompletionQueue* notification_cq, void *tag) { + ::grpc::Service::RequestAsyncUnary(0, context, request, response, new_call_cq, notification_cq, tag); + } + }; + typedef WithAsyncMethod_DoAdd AsyncService; + template + class WithGenericMethod_DoAdd : public BaseClass { + private: + void BaseClassMustBeDerivedFromService(const Service *service) {} + public: + WithGenericMethod_DoAdd() { + ::grpc::Service::MarkMethodGeneric(0); + } + ~WithGenericMethod_DoAdd() GRPC_OVERRIDE { + BaseClassMustBeDerivedFromService(this); + } + // disable synchronous version of this method + ::grpc::Status DoAdd(::grpc::ServerContext* context, const ::echo::AddRequest* request, ::echo::AddResponse* response) GRPC_FINAL GRPC_OVERRIDE { + abort(); + return ::grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); + } + }; +}; + } // namespace echo diff --git a/examples/echo/echo-cpp/echo.pb.cc b/examples/echo/echo-cpp/echo.pb.cc index c68cac1..9ecfa3b 100644 --- a/examples/echo/echo-cpp/echo.pb.cc +++ b/examples/echo/echo-cpp/echo.pb.cc @@ -24,6 +24,12 @@ namespace { const ::google::protobuf::Descriptor* EchoRequest_descriptor_ = NULL; const ::google::protobuf::internal::GeneratedMessageReflection* EchoRequest_reflection_ = NULL; +const ::google::protobuf::Descriptor* AddRequest_descriptor_ = NULL; +const ::google::protobuf::internal::GeneratedMessageReflection* + AddRequest_reflection_ = NULL; +const ::google::protobuf::Descriptor* AddResponse_descriptor_ = NULL; +const ::google::protobuf::internal::GeneratedMessageReflection* + AddResponse_reflection_ = NULL; } // namespace @@ -49,6 +55,37 @@ void protobuf_AssignDesc_echo_2eproto() { sizeof(EchoRequest), GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(EchoRequest, _internal_metadata_), GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(EchoRequest, _is_default_instance_)); + AddRequest_descriptor_ = file->message_type(1); + static const int AddRequest_offsets_[2] = { + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(AddRequest, addx_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(AddRequest, addy_), + }; + AddRequest_reflection_ = + ::google::protobuf::internal::GeneratedMessageReflection::NewGeneratedMessageReflection( + AddRequest_descriptor_, + AddRequest::default_instance_, + AddRequest_offsets_, + -1, + -1, + -1, + sizeof(AddRequest), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(AddRequest, _internal_metadata_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(AddRequest, _is_default_instance_)); + AddResponse_descriptor_ = file->message_type(2); + static const int AddResponse_offsets_[1] = { + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(AddResponse, answer_), + }; + AddResponse_reflection_ = + ::google::protobuf::internal::GeneratedMessageReflection::NewGeneratedMessageReflection( + AddResponse_descriptor_, + AddResponse::default_instance_, + AddResponse_offsets_, + -1, + -1, + -1, + sizeof(AddResponse), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(AddResponse, _internal_metadata_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(AddResponse, _is_default_instance_)); } namespace { @@ -63,6 +100,10 @@ void protobuf_RegisterTypes(const ::std::string&) { protobuf_AssignDescriptorsOnce(); ::google::protobuf::MessageFactory::InternalRegisterGeneratedMessage( EchoRequest_descriptor_, &EchoRequest::default_instance()); + ::google::protobuf::MessageFactory::InternalRegisterGeneratedMessage( + AddRequest_descriptor_, &AddRequest::default_instance()); + ::google::protobuf::MessageFactory::InternalRegisterGeneratedMessage( + AddResponse_descriptor_, &AddResponse::default_instance()); } } // namespace @@ -70,6 +111,10 @@ void protobuf_RegisterTypes(const ::std::string&) { void protobuf_ShutdownFile_echo_2eproto() { delete EchoRequest::default_instance_; delete EchoRequest_reflection_; + delete AddRequest::default_instance_; + delete AddRequest_reflection_; + delete AddResponse::default_instance_; + delete AddResponse_reflection_; } void protobuf_AddDesc_echo_2eproto() { @@ -80,12 +125,20 @@ void protobuf_AddDesc_echo_2eproto() { ::google::protobuf::DescriptorPool::InternalAddGeneratedFile( "\n\necho.proto\022\004echo\"\036\n\013EchoRequest\022\017\n\007mes" - "sage\030\001 \001(\t28\n\004Echo\0220\n\006DoEcho\022\021.echo.Echo" - "Request\032\021.echo.EchoRequest\"\000b\006proto3", 116); + "sage\030\001 \001(\t\"(\n\nAddRequest\022\014\n\004addX\030\001 \001(\007\022\014" + "\n\004addY\030\002 \001(\007\"\035\n\013AddResponse\022\016\n\006answer\030\001 " + "\001(\00728\n\004Echo\0220\n\006DoEcho\022\021.echo.EchoRequest" + "\032\021.echo.EchoRequest\"\00025\n\003Add\022.\n\005DoAdd\022\020." + "echo.AddRequest\032\021.echo.AddResponse\"\000b\006pr" + "oto3", 244); ::google::protobuf::MessageFactory::InternalRegisterGeneratedFile( "echo.proto", &protobuf_RegisterTypes); EchoRequest::default_instance_ = new EchoRequest(); + AddRequest::default_instance_ = new AddRequest(); + AddResponse::default_instance_ = new AddResponse(); EchoRequest::default_instance_->InitAsDefaultInstance(); + AddRequest::default_instance_->InitAsDefaultInstance(); + AddResponse::default_instance_->InitAsDefaultInstance(); ::google::protobuf::internal::OnShutdown(&protobuf_ShutdownFile_echo_2eproto); } @@ -377,6 +430,516 @@ void EchoRequest::clear_message() { #endif // PROTOBUF_INLINE_NOT_IN_HEADERS +// =================================================================== + +#if !defined(_MSC_VER) || _MSC_VER >= 1900 +const int AddRequest::kAddXFieldNumber; +const int AddRequest::kAddYFieldNumber; +#endif // !defined(_MSC_VER) || _MSC_VER >= 1900 + +AddRequest::AddRequest() + : ::google::protobuf::Message(), _internal_metadata_(NULL) { + SharedCtor(); + // @@protoc_insertion_point(constructor:echo.AddRequest) +} + +void AddRequest::InitAsDefaultInstance() { + _is_default_instance_ = true; +} + +AddRequest::AddRequest(const AddRequest& from) + : ::google::protobuf::Message(), + _internal_metadata_(NULL) { + SharedCtor(); + MergeFrom(from); + // @@protoc_insertion_point(copy_constructor:echo.AddRequest) +} + +void AddRequest::SharedCtor() { + _is_default_instance_ = false; + _cached_size_ = 0; + addx_ = 0u; + addy_ = 0u; +} + +AddRequest::~AddRequest() { + // @@protoc_insertion_point(destructor:echo.AddRequest) + SharedDtor(); +} + +void AddRequest::SharedDtor() { + if (this != default_instance_) { + } +} + +void AddRequest::SetCachedSize(int size) const { + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); +} +const ::google::protobuf::Descriptor* AddRequest::descriptor() { + protobuf_AssignDescriptorsOnce(); + return AddRequest_descriptor_; +} + +const AddRequest& AddRequest::default_instance() { + if (default_instance_ == NULL) protobuf_AddDesc_echo_2eproto(); + return *default_instance_; +} + +AddRequest* AddRequest::default_instance_ = NULL; + +AddRequest* AddRequest::New(::google::protobuf::Arena* arena) const { + AddRequest* n = new AddRequest; + if (arena != NULL) { + arena->Own(n); + } + return n; +} + +void AddRequest::Clear() { +#define ZR_HELPER_(f) reinterpret_cast(\ + &reinterpret_cast(16)->f) + +#define ZR_(first, last) do {\ + ::memset(&first, 0,\ + ZR_HELPER_(last) - ZR_HELPER_(first) + sizeof(last));\ +} while (0) + + ZR_(addx_, addy_); + +#undef ZR_HELPER_ +#undef ZR_ + +} + +bool AddRequest::MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) { +#define DO_(EXPRESSION) if (!(EXPRESSION)) goto failure + ::google::protobuf::uint32 tag; + // @@protoc_insertion_point(parse_start:echo.AddRequest) + for (;;) { + ::std::pair< ::google::protobuf::uint32, bool> p = input->ReadTagWithCutoff(127); + tag = p.first; + if (!p.second) goto handle_unusual; + switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) { + // optional fixed32 addX = 1; + case 1: { + if (tag == 13) { + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_FIXED32>( + input, &addx_))); + + } else { + goto handle_unusual; + } + if (input->ExpectTag(21)) goto parse_addY; + break; + } + + // optional fixed32 addY = 2; + case 2: { + if (tag == 21) { + parse_addY: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_FIXED32>( + input, &addy_))); + + } else { + goto handle_unusual; + } + if (input->ExpectAtEnd()) goto success; + break; + } + + default: { + handle_unusual: + if (tag == 0 || + ::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_END_GROUP) { + goto success; + } + DO_(::google::protobuf::internal::WireFormatLite::SkipField(input, tag)); + break; + } + } + } +success: + // @@protoc_insertion_point(parse_success:echo.AddRequest) + return true; +failure: + // @@protoc_insertion_point(parse_failure:echo.AddRequest) + return false; +#undef DO_ +} + +void AddRequest::SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const { + // @@protoc_insertion_point(serialize_start:echo.AddRequest) + // optional fixed32 addX = 1; + if (this->addx() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteFixed32(1, this->addx(), output); + } + + // optional fixed32 addY = 2; + if (this->addy() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteFixed32(2, this->addy(), output); + } + + // @@protoc_insertion_point(serialize_end:echo.AddRequest) +} + +::google::protobuf::uint8* AddRequest::SerializeWithCachedSizesToArray( + ::google::protobuf::uint8* target) const { + // @@protoc_insertion_point(serialize_to_array_start:echo.AddRequest) + // optional fixed32 addX = 1; + if (this->addx() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteFixed32ToArray(1, this->addx(), target); + } + + // optional fixed32 addY = 2; + if (this->addy() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteFixed32ToArray(2, this->addy(), target); + } + + // @@protoc_insertion_point(serialize_to_array_end:echo.AddRequest) + return target; +} + +int AddRequest::ByteSize() const { + int total_size = 0; + + // optional fixed32 addX = 1; + if (this->addx() != 0) { + total_size += 1 + 4; + } + + // optional fixed32 addY = 2; + if (this->addy() != 0) { + total_size += 1 + 4; + } + + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = total_size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); + return total_size; +} + +void AddRequest::MergeFrom(const ::google::protobuf::Message& from) { + if (GOOGLE_PREDICT_FALSE(&from == this)) MergeFromFail(__LINE__); + const AddRequest* source = + ::google::protobuf::internal::DynamicCastToGenerated( + &from); + if (source == NULL) { + ::google::protobuf::internal::ReflectionOps::Merge(from, this); + } else { + MergeFrom(*source); + } +} + +void AddRequest::MergeFrom(const AddRequest& from) { + if (GOOGLE_PREDICT_FALSE(&from == this)) MergeFromFail(__LINE__); + if (from.addx() != 0) { + set_addx(from.addx()); + } + if (from.addy() != 0) { + set_addy(from.addy()); + } +} + +void AddRequest::CopyFrom(const ::google::protobuf::Message& from) { + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void AddRequest::CopyFrom(const AddRequest& from) { + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool AddRequest::IsInitialized() const { + + return true; +} + +void AddRequest::Swap(AddRequest* other) { + if (other == this) return; + InternalSwap(other); +} +void AddRequest::InternalSwap(AddRequest* other) { + std::swap(addx_, other->addx_); + std::swap(addy_, other->addy_); + _internal_metadata_.Swap(&other->_internal_metadata_); + std::swap(_cached_size_, other->_cached_size_); +} + +::google::protobuf::Metadata AddRequest::GetMetadata() const { + protobuf_AssignDescriptorsOnce(); + ::google::protobuf::Metadata metadata; + metadata.descriptor = AddRequest_descriptor_; + metadata.reflection = AddRequest_reflection_; + return metadata; +} + +#if PROTOBUF_INLINE_NOT_IN_HEADERS +// AddRequest + +// optional fixed32 addX = 1; +void AddRequest::clear_addx() { + addx_ = 0u; +} + ::google::protobuf::uint32 AddRequest::addx() const { + // @@protoc_insertion_point(field_get:echo.AddRequest.addX) + return addx_; +} + void AddRequest::set_addx(::google::protobuf::uint32 value) { + + addx_ = value; + // @@protoc_insertion_point(field_set:echo.AddRequest.addX) +} + +// optional fixed32 addY = 2; +void AddRequest::clear_addy() { + addy_ = 0u; +} + ::google::protobuf::uint32 AddRequest::addy() const { + // @@protoc_insertion_point(field_get:echo.AddRequest.addY) + return addy_; +} + void AddRequest::set_addy(::google::protobuf::uint32 value) { + + addy_ = value; + // @@protoc_insertion_point(field_set:echo.AddRequest.addY) +} + +#endif // PROTOBUF_INLINE_NOT_IN_HEADERS + +// =================================================================== + +#if !defined(_MSC_VER) || _MSC_VER >= 1900 +const int AddResponse::kAnswerFieldNumber; +#endif // !defined(_MSC_VER) || _MSC_VER >= 1900 + +AddResponse::AddResponse() + : ::google::protobuf::Message(), _internal_metadata_(NULL) { + SharedCtor(); + // @@protoc_insertion_point(constructor:echo.AddResponse) +} + +void AddResponse::InitAsDefaultInstance() { + _is_default_instance_ = true; +} + +AddResponse::AddResponse(const AddResponse& from) + : ::google::protobuf::Message(), + _internal_metadata_(NULL) { + SharedCtor(); + MergeFrom(from); + // @@protoc_insertion_point(copy_constructor:echo.AddResponse) +} + +void AddResponse::SharedCtor() { + _is_default_instance_ = false; + _cached_size_ = 0; + answer_ = 0u; +} + +AddResponse::~AddResponse() { + // @@protoc_insertion_point(destructor:echo.AddResponse) + SharedDtor(); +} + +void AddResponse::SharedDtor() { + if (this != default_instance_) { + } +} + +void AddResponse::SetCachedSize(int size) const { + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); +} +const ::google::protobuf::Descriptor* AddResponse::descriptor() { + protobuf_AssignDescriptorsOnce(); + return AddResponse_descriptor_; +} + +const AddResponse& AddResponse::default_instance() { + if (default_instance_ == NULL) protobuf_AddDesc_echo_2eproto(); + return *default_instance_; +} + +AddResponse* AddResponse::default_instance_ = NULL; + +AddResponse* AddResponse::New(::google::protobuf::Arena* arena) const { + AddResponse* n = new AddResponse; + if (arena != NULL) { + arena->Own(n); + } + return n; +} + +void AddResponse::Clear() { + answer_ = 0u; +} + +bool AddResponse::MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input) { +#define DO_(EXPRESSION) if (!(EXPRESSION)) goto failure + ::google::protobuf::uint32 tag; + // @@protoc_insertion_point(parse_start:echo.AddResponse) + for (;;) { + ::std::pair< ::google::protobuf::uint32, bool> p = input->ReadTagWithCutoff(127); + tag = p.first; + if (!p.second) goto handle_unusual; + switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) { + // optional fixed32 answer = 1; + case 1: { + if (tag == 13) { + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + ::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_FIXED32>( + input, &answer_))); + + } else { + goto handle_unusual; + } + if (input->ExpectAtEnd()) goto success; + break; + } + + default: { + handle_unusual: + if (tag == 0 || + ::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == + ::google::protobuf::internal::WireFormatLite::WIRETYPE_END_GROUP) { + goto success; + } + DO_(::google::protobuf::internal::WireFormatLite::SkipField(input, tag)); + break; + } + } + } +success: + // @@protoc_insertion_point(parse_success:echo.AddResponse) + return true; +failure: + // @@protoc_insertion_point(parse_failure:echo.AddResponse) + return false; +#undef DO_ +} + +void AddResponse::SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const { + // @@protoc_insertion_point(serialize_start:echo.AddResponse) + // optional fixed32 answer = 1; + if (this->answer() != 0) { + ::google::protobuf::internal::WireFormatLite::WriteFixed32(1, this->answer(), output); + } + + // @@protoc_insertion_point(serialize_end:echo.AddResponse) +} + +::google::protobuf::uint8* AddResponse::SerializeWithCachedSizesToArray( + ::google::protobuf::uint8* target) const { + // @@protoc_insertion_point(serialize_to_array_start:echo.AddResponse) + // optional fixed32 answer = 1; + if (this->answer() != 0) { + target = ::google::protobuf::internal::WireFormatLite::WriteFixed32ToArray(1, this->answer(), target); + } + + // @@protoc_insertion_point(serialize_to_array_end:echo.AddResponse) + return target; +} + +int AddResponse::ByteSize() const { + int total_size = 0; + + // optional fixed32 answer = 1; + if (this->answer() != 0) { + total_size += 1 + 4; + } + + GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); + _cached_size_ = total_size; + GOOGLE_SAFE_CONCURRENT_WRITES_END(); + return total_size; +} + +void AddResponse::MergeFrom(const ::google::protobuf::Message& from) { + if (GOOGLE_PREDICT_FALSE(&from == this)) MergeFromFail(__LINE__); + const AddResponse* source = + ::google::protobuf::internal::DynamicCastToGenerated( + &from); + if (source == NULL) { + ::google::protobuf::internal::ReflectionOps::Merge(from, this); + } else { + MergeFrom(*source); + } +} + +void AddResponse::MergeFrom(const AddResponse& from) { + if (GOOGLE_PREDICT_FALSE(&from == this)) MergeFromFail(__LINE__); + if (from.answer() != 0) { + set_answer(from.answer()); + } +} + +void AddResponse::CopyFrom(const ::google::protobuf::Message& from) { + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +void AddResponse::CopyFrom(const AddResponse& from) { + if (&from == this) return; + Clear(); + MergeFrom(from); +} + +bool AddResponse::IsInitialized() const { + + return true; +} + +void AddResponse::Swap(AddResponse* other) { + if (other == this) return; + InternalSwap(other); +} +void AddResponse::InternalSwap(AddResponse* other) { + std::swap(answer_, other->answer_); + _internal_metadata_.Swap(&other->_internal_metadata_); + std::swap(_cached_size_, other->_cached_size_); +} + +::google::protobuf::Metadata AddResponse::GetMetadata() const { + protobuf_AssignDescriptorsOnce(); + ::google::protobuf::Metadata metadata; + metadata.descriptor = AddResponse_descriptor_; + metadata.reflection = AddResponse_reflection_; + return metadata; +} + +#if PROTOBUF_INLINE_NOT_IN_HEADERS +// AddResponse + +// optional fixed32 answer = 1; +void AddResponse::clear_answer() { + answer_ = 0u; +} + ::google::protobuf::uint32 AddResponse::answer() const { + // @@protoc_insertion_point(field_get:echo.AddResponse.answer) + return answer_; +} + void AddResponse::set_answer(::google::protobuf::uint32 value) { + + answer_ = value; + // @@protoc_insertion_point(field_set:echo.AddResponse.answer) +} + +#endif // PROTOBUF_INLINE_NOT_IN_HEADERS + // @@protoc_insertion_point(namespace_scope) } // namespace echo diff --git a/examples/echo/echo-cpp/echo.pb.h b/examples/echo/echo-cpp/echo.pb.h index 089f94a..d430892 100644 --- a/examples/echo/echo-cpp/echo.pb.h +++ b/examples/echo/echo-cpp/echo.pb.h @@ -36,6 +36,8 @@ void protobuf_AddDesc_echo_2eproto(); void protobuf_AssignDesc_echo_2eproto(); void protobuf_ShutdownFile_echo_2eproto(); +class AddRequest; +class AddResponse; class EchoRequest; // =================================================================== @@ -121,6 +123,169 @@ class EchoRequest : public ::google::protobuf::Message { void InitAsDefaultInstance(); static EchoRequest* default_instance_; }; +// ------------------------------------------------------------------- + +class AddRequest : public ::google::protobuf::Message { + public: + AddRequest(); + virtual ~AddRequest(); + + AddRequest(const AddRequest& from); + + inline AddRequest& operator=(const AddRequest& from) { + CopyFrom(from); + return *this; + } + + static const ::google::protobuf::Descriptor* descriptor(); + static const AddRequest& default_instance(); + + void Swap(AddRequest* other); + + // implements Message ---------------------------------------------- + + inline AddRequest* New() const { return New(NULL); } + + AddRequest* New(::google::protobuf::Arena* arena) const; + void CopyFrom(const ::google::protobuf::Message& from); + void MergeFrom(const ::google::protobuf::Message& from); + void CopyFrom(const AddRequest& from); + void MergeFrom(const AddRequest& from); + void Clear(); + bool IsInitialized() const; + + int ByteSize() const; + bool MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input); + void SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const; + ::google::protobuf::uint8* SerializeWithCachedSizesToArray(::google::protobuf::uint8* output) const; + int GetCachedSize() const { return _cached_size_; } + private: + void SharedCtor(); + void SharedDtor(); + void SetCachedSize(int size) const; + void InternalSwap(AddRequest* other); + private: + inline ::google::protobuf::Arena* GetArenaNoVirtual() const { + return _internal_metadata_.arena(); + } + inline void* MaybeArenaPtr() const { + return _internal_metadata_.raw_arena_ptr(); + } + public: + + ::google::protobuf::Metadata GetMetadata() const; + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + // optional fixed32 addX = 1; + void clear_addx(); + static const int kAddXFieldNumber = 1; + ::google::protobuf::uint32 addx() const; + void set_addx(::google::protobuf::uint32 value); + + // optional fixed32 addY = 2; + void clear_addy(); + static const int kAddYFieldNumber = 2; + ::google::protobuf::uint32 addy() const; + void set_addy(::google::protobuf::uint32 value); + + // @@protoc_insertion_point(class_scope:echo.AddRequest) + private: + + ::google::protobuf::internal::InternalMetadataWithArena _internal_metadata_; + bool _is_default_instance_; + ::google::protobuf::uint32 addx_; + ::google::protobuf::uint32 addy_; + mutable int _cached_size_; + friend void protobuf_AddDesc_echo_2eproto(); + friend void protobuf_AssignDesc_echo_2eproto(); + friend void protobuf_ShutdownFile_echo_2eproto(); + + void InitAsDefaultInstance(); + static AddRequest* default_instance_; +}; +// ------------------------------------------------------------------- + +class AddResponse : public ::google::protobuf::Message { + public: + AddResponse(); + virtual ~AddResponse(); + + AddResponse(const AddResponse& from); + + inline AddResponse& operator=(const AddResponse& from) { + CopyFrom(from); + return *this; + } + + static const ::google::protobuf::Descriptor* descriptor(); + static const AddResponse& default_instance(); + + void Swap(AddResponse* other); + + // implements Message ---------------------------------------------- + + inline AddResponse* New() const { return New(NULL); } + + AddResponse* New(::google::protobuf::Arena* arena) const; + void CopyFrom(const ::google::protobuf::Message& from); + void MergeFrom(const ::google::protobuf::Message& from); + void CopyFrom(const AddResponse& from); + void MergeFrom(const AddResponse& from); + void Clear(); + bool IsInitialized() const; + + int ByteSize() const; + bool MergePartialFromCodedStream( + ::google::protobuf::io::CodedInputStream* input); + void SerializeWithCachedSizes( + ::google::protobuf::io::CodedOutputStream* output) const; + ::google::protobuf::uint8* SerializeWithCachedSizesToArray(::google::protobuf::uint8* output) const; + int GetCachedSize() const { return _cached_size_; } + private: + void SharedCtor(); + void SharedDtor(); + void SetCachedSize(int size) const; + void InternalSwap(AddResponse* other); + private: + inline ::google::protobuf::Arena* GetArenaNoVirtual() const { + return _internal_metadata_.arena(); + } + inline void* MaybeArenaPtr() const { + return _internal_metadata_.raw_arena_ptr(); + } + public: + + ::google::protobuf::Metadata GetMetadata() const; + + // nested types ---------------------------------------------------- + + // accessors ------------------------------------------------------- + + // optional fixed32 answer = 1; + void clear_answer(); + static const int kAnswerFieldNumber = 1; + ::google::protobuf::uint32 answer() const; + void set_answer(::google::protobuf::uint32 value); + + // @@protoc_insertion_point(class_scope:echo.AddResponse) + private: + + ::google::protobuf::internal::InternalMetadataWithArena _internal_metadata_; + bool _is_default_instance_; + ::google::protobuf::uint32 answer_; + mutable int _cached_size_; + friend void protobuf_AddDesc_echo_2eproto(); + friend void protobuf_AssignDesc_echo_2eproto(); + friend void protobuf_ShutdownFile_echo_2eproto(); + + void InitAsDefaultInstance(); + static AddResponse* default_instance_; +}; // =================================================================== @@ -172,7 +337,61 @@ inline void EchoRequest::set_allocated_message(::std::string* message) { // @@protoc_insertion_point(field_set_allocated:echo.EchoRequest.message) } +// ------------------------------------------------------------------- + +// AddRequest + +// optional fixed32 addX = 1; +inline void AddRequest::clear_addx() { + addx_ = 0u; +} +inline ::google::protobuf::uint32 AddRequest::addx() const { + // @@protoc_insertion_point(field_get:echo.AddRequest.addX) + return addx_; +} +inline void AddRequest::set_addx(::google::protobuf::uint32 value) { + + addx_ = value; + // @@protoc_insertion_point(field_set:echo.AddRequest.addX) +} + +// optional fixed32 addY = 2; +inline void AddRequest::clear_addy() { + addy_ = 0u; +} +inline ::google::protobuf::uint32 AddRequest::addy() const { + // @@protoc_insertion_point(field_get:echo.AddRequest.addY) + return addy_; +} +inline void AddRequest::set_addy(::google::protobuf::uint32 value) { + + addy_ = value; + // @@protoc_insertion_point(field_set:echo.AddRequest.addY) +} + +// ------------------------------------------------------------------- + +// AddResponse + +// optional fixed32 answer = 1; +inline void AddResponse::clear_answer() { + answer_ = 0u; +} +inline ::google::protobuf::uint32 AddResponse::answer() const { + // @@protoc_insertion_point(field_get:echo.AddResponse.answer) + return answer_; +} +inline void AddResponse::set_answer(::google::protobuf::uint32 value) { + + answer_ = value; + // @@protoc_insertion_point(field_set:echo.AddResponse.answer) +} + #endif // !PROTOBUF_INLINE_NOT_IN_HEADERS +// ------------------------------------------------------------------- + +// ------------------------------------------------------------------- + // @@protoc_insertion_point(namespace_scope) diff --git a/examples/echo/echo-server/Main.hs b/examples/echo/echo-server/Main.hs index 8c97856..2d0070e 100644 --- a/examples/echo/echo-server/Main.hs +++ b/examples/echo/echo-server/Main.hs @@ -1,7 +1,8 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE OverloadedLists #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} {-# OPTIONS_GHC -fno-warn-missing-signatures #-} {-# OPTIONS_GHC -fno-warn-unused-binds #-} @@ -9,10 +10,14 @@ import Control.Concurrent import Control.Concurrent.Async import Control.Monad import Data.ByteString (ByteString) +import Data.Protobuf.Wire.Class +import qualified Data.Text as T +import Data.Word +import GHC.Generics (Generic) +import Network.GRPC.HighLevel.Server import Network.GRPC.LowLevel -import Network.GRPC.LowLevel.Call +import qualified Network.GRPC.LowLevel.Call.Unregistered as U import qualified Network.GRPC.LowLevel.Server.Unregistered as U -import qualified Network.GRPC.LowLevel.Call.Unregistered as U serverMeta :: MetadataMap serverMeta = [("test_meta", "test_meta_value")] @@ -27,7 +32,7 @@ handler U.ServerCall{..} reqBody = do unregMain :: IO () unregMain = withGRPC $ \grpc -> do - withServer grpc (ServerConfig "localhost" 50051 [] []) $ \server -> forever $ do + withServer grpc defConfig $ \server -> forever $ do result <- U.serverHandleNormalCall server serverMeta handler case result of Left x -> putStrLn $ "handle call result error: " ++ show x @@ -35,8 +40,8 @@ unregMain = withGRPC $ \grpc -> do regMain :: IO () regMain = withGRPC $ \grpc -> do - let methods = [(MethodName "/echo.Echo/DoEcho", Normal)] - withServer grpc (ServerConfig "localhost" 50051 methods []) $ \server -> + let ms = [(MethodName "/echo.Echo/DoEcho")] + withServer grpc (defConfig {methodsToRegisterNormal = ms}) $ \server -> forever $ do let method = head (normalMethods server) result <- serverHandleNormalCall server method serverMeta $ @@ -63,13 +68,56 @@ regLoop server method = forever $ do regMainThreaded :: IO () regMainThreaded = do withGRPC $ \grpc -> do - let methods = [(MethodName "/echo.Echo/DoEcho", Normal)] - withServer grpc (ServerConfig "localhost" 50051 methods []) $ \server -> do + let ms = [(MethodName "/echo.Echo/DoEcho")] + withServer grpc (defConfig {methodsToRegisterNormal = ms}) $ \server -> do let method = head (normalMethods server) tids <- replicateM 7 $ async $ do tputStrLn "starting handler" regLoop server method waitAnyCancel tids tputStrLn "finishing" +-- NB: If you change these, make sure to change them in the client as well. +-- TODO: Put these in a common location (or just hack around it until CG is working) +data EchoRequest = EchoRequest {message :: T.Text} deriving (Show, Eq, Ord, Generic) +instance Message EchoRequest +data AddRequest = AddRequest {addX :: Word32, addY :: Word32} deriving (Show, Eq, Ord, Generic) +instance Message AddRequest +data AddResponse = AddResponse {answer :: Word32} deriving (Show, Eq, Ord, Generic) +instance Message AddResponse + +highlevelMain :: IO () +highlevelMain = + serverLoop defaultOptions{optNormalHandlers = [echoHandler, addHandler]} + where echoHandler = + UnaryHandler "/echo.Echo/DoEcho" $ + \_c body m -> do + tputStrLn $ "UnaryHandler for DoEcho hit, body=" ++ show body + return ( body :: EchoRequest + , m + , StatusOk + , StatusDetails "" + ) + addHandler = + --TODO: I can't get this one to execute. Is the generated method + --name different? + + -- static const char* Add_method_names[] = { + -- "/echo.Add/DoAdd", + -- }; + + UnaryHandler "/echo.Add/DoAdd" $ + \_c b m -> do + tputStrLn $ "UnaryHandler for DoAdd hit, b=" ++ show b + print (addX b) + print (addY b) + return ( AddResponse $ addX b + addY b + , m + , StatusOk + , StatusDetails "" + ) + main :: IO () -main = regMainThreaded +main = highlevelMain + +defConfig :: ServerConfig +defConfig = ServerConfig "localhost" 50051 [] [] [] [] [] diff --git a/examples/echo/echo.proto b/examples/echo/echo.proto index ed192c7..5ac72fe 100644 --- a/examples/echo/echo.proto +++ b/examples/echo/echo.proto @@ -9,3 +9,16 @@ service Echo { message EchoRequest { string message = 1; } + +message AddRequest { + fixed32 addX = 1; + fixed32 addY = 2; +} + +message AddResponse { + fixed32 answer = 1; +} + +service Add { + rpc DoAdd (AddRequest) returns (AddResponse) {} +} diff --git a/grpc-haskell.cabal b/grpc-haskell.cabal index 5da6c4f..f29fb32 100644 --- a/grpc-haskell.cabal +++ b/grpc-haskell.cabal @@ -33,6 +33,8 @@ library , managed >= 1.0.5 && < 1.1 , pipes ==4.1.* , transformers + , proto3-wire + , protobuf-wire , async , tasty >= 0.11 && <0.12 @@ -63,6 +65,8 @@ library Network.GRPC.LowLevel.Call Network.GRPC.LowLevel.Call.Unregistered Network.GRPC.LowLevel.Client + Network.GRPC.HighLevel + Network.GRPC.HighLevel.Server extra-libraries: grpc includes: @@ -86,11 +90,14 @@ library executable echo-server if flag(with-examples) build-depends: - base ==4.8.* - , bytestring == 0.10.* - , grpc-haskell - , containers ==0.5.* + base ==4.8.* , async + , bytestring == 0.10.* + , containers ==0.5.* + , grpc-haskell + , proto3-wire + , protobuf-wire + , text else buildable: False default-language: Haskell2010 @@ -101,10 +108,14 @@ executable echo-server executable echo-client if flag(with-examples) build-depends: - base ==4.8.* + base ==4.8.* + , async , bytestring == 0.10.* - , grpc-haskell , containers ==0.5.* + , grpc-haskell + , proto3-wire + , protobuf-wire + , text else buildable: False default-language: Haskell2010 diff --git a/src/Network/GRPC/HighLevel.hs b/src/Network/GRPC/HighLevel.hs new file mode 100644 index 0000000..a230a3d --- /dev/null +++ b/src/Network/GRPC/HighLevel.hs @@ -0,0 +1,10 @@ +module Network.GRPC.HighLevel ( +-- * Server +Handler(..) +, ServerOptions(..) +, defaultOptions +, serverLoop +) + where + +import Network.GRPC.HighLevel.Server diff --git a/src/Network/GRPC/HighLevel/Server.hs b/src/Network/GRPC/HighLevel/Server.hs new file mode 100644 index 0000000..0b9b1cd --- /dev/null +++ b/src/Network/GRPC/HighLevel/Server.hs @@ -0,0 +1,223 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RecordWildCards #-} + +module Network.GRPC.HighLevel.Server where + +import Control.Concurrent.Async +import Control.Monad +import Data.ByteString (ByteString) +import qualified Data.ByteString.Lazy as BL +import Data.Protobuf.Wire.Class +import Network.GRPC.LowLevel +import qualified Network.GRPC.LowLevel.Call.Unregistered as U +import Network.GRPC.LowLevel.GRPC +import qualified Network.GRPC.LowLevel.Server.Unregistered as U + +type ServerHandler' a b = + forall c . + ServerCall c + -> a + -> MetadataMap + -> IO (b, MetadataMap, StatusCode, StatusDetails) + +convertServerHandler :: (Message a, Message b) + => ServerHandler' a b + -> ServerHandler +convertServerHandler f c bs m = case fromByteString bs of + Left{} -> error "TODO: find a way to keep this from killing the server." + Right x -> do (y, tm, sc, sd) <- f c x m + return (toBS y, tm, sc, sd) + +type ServerReaderHandler' a b = + ServerCall () + -> StreamRecv a + -> Streaming (Maybe b, MetadataMap, StatusCode, StatusDetails) + +convertServerReaderHandler :: (Message a, Message b) + => ServerReaderHandler' a b + -> ServerReaderHandler +convertServerReaderHandler f c recv = + serialize <$> f c (convertRecv recv) + where + serialize (mmsg, m, sc, sd) = (toBS <$> mmsg, m, sc, sd) + +type ServerWriterHandler' a b = + ServerCall a + -> StreamSend b + -> Streaming (MetadataMap, StatusCode, StatusDetails) + +convertServerWriterHandler :: (Message a, Message b) => + ServerWriterHandler' a b + -> ServerWriterHandler +convertServerWriterHandler f c send = + f (convert <$> c) (convertSend send) + where + convert bs = case fromByteString bs of + Left x -> error $ "deserialization error: " ++ show x -- TODO FIXME + Right x -> x + +type ServerRWHandler' a b = + ServerCall () + -> StreamRecv a + -> StreamSend b + -> Streaming (MetadataMap, StatusCode, StatusDetails) + +convertServerRWHandler :: (Message a, Message b) + => ServerRWHandler' a b + -> ServerRWHandler +convertServerRWHandler f c recv send = + f c (convertRecv recv) (convertSend send) + +convertRecv :: Message a => StreamRecv ByteString -> StreamRecv a +convertRecv = + fmap $ \e -> do + msg <- e + case msg of + Nothing -> return Nothing + Just bs -> case fromByteString bs of + Left x -> Left (GRPCIODecodeError x) + Right x -> return (Just x) + +convertSend :: Message a => StreamSend ByteString -> StreamSend a +convertSend s = s . toBS + +toBS :: Message a => a -> ByteString +toBS = BL.toStrict . toLazyByteString + +data Handler a where + UnaryHandler + :: (Message c, Message d) + => MethodName + -> ServerHandler' c d + -> Handler 'Normal + + ClientStreamHandler + :: (Message c, Message d) + => MethodName + -> ServerReaderHandler' c d + -> Handler 'ClientStreaming + + ServerStreamHandler + :: (Message c, Message d) + => MethodName + -> ServerWriterHandler' c d + -> Handler 'ServerStreaming + + BiDiStreamHandler + :: (Message c, Message d) + => MethodName + -> ServerRWHandler' c d + -> Handler 'BiDiStreaming + +handlerMethodName :: Handler a -> MethodName +handlerMethodName (UnaryHandler m _) = m +handlerMethodName (ClientStreamHandler m _) = m +handlerMethodName (ServerStreamHandler m _) = m +handlerMethodName (BiDiStreamHandler m _) = m + +-- TODO: find some idiomatic way to do logging that doesn't force the user +-- into anything they don't want. +logShow :: Show a => a -> IO () +logShow = print + +logAskReport :: Show a => a -> IO () +logAskReport x = + logShow $ show x ++ " This probably indicates a bug in gRPC-haskell. Please report this error." + +-- | Handles errors that result from trying to handle a call on the server. +-- For each error, takes a different action depending on the severity in the +-- context of handling a server call. This also tries to give an indication of +-- whether the error is our fault or user error. +handleCallError :: Either GRPCIOError a -> IO () +handleCallError (Right _) = return () +handleCallError (Left GRPCIOTimeout) = + --Probably a benign timeout (such as a client disappearing), noop for now. + return () +handleCallError (Left GRPCIOShutdown) = + --Server shutting down. Benign. + return () +handleCallError (Left x) = logAskReport x + +loopWError :: IO (Either GRPCIOError a) -> IO () +loopWError f = forever $ f >>= handleCallError + +--TODO: options for setting initial/trailing metadata +handleLoop :: Server -> (Handler a, RegisteredMethod a) -> IO () +handleLoop s (UnaryHandler _ f, rm) = + loopWError $ do + grpcDebug' "handleLoop about to block on serverHandleNormalCall" + serverHandleNormalCall s rm mempty $ convertServerHandler f +handleLoop s (ClientStreamHandler _ f, rm) = + loopWError $ serverReader s rm mempty $ convertServerReaderHandler f +handleLoop s (ServerStreamHandler _ f, rm) = + loopWError $ serverWriter s rm mempty $ convertServerWriterHandler f +handleLoop s (BiDiStreamHandler _ f, rm) = + loopWError $ serverRW s rm mempty $ convertServerRWHandler f + +data ServerOptions = ServerOptions + {optNormalHandlers :: [Handler 'Normal], + optClientStreamHandlers :: [Handler 'ClientStreaming], + optServerStreamHandlers :: [Handler 'ServerStreaming], + optBiDiStreamHandlers :: [Handler 'BiDiStreaming], + optServerPort :: Port, + optUseCompression :: Bool, + optUserAgentPrefix :: String, + optUserAgentSuffix :: String} + +defaultOptions :: ServerOptions +defaultOptions = + ServerOptions {optNormalHandlers = [], + optClientStreamHandlers = [], + optServerStreamHandlers = [], + optBiDiStreamHandlers = [], + optServerPort = 50051, + optUseCompression = False, + optUserAgentPrefix = "grpc-haskell/0.0.0", + optUserAgentSuffix = ""} + +serverLoop :: ServerOptions -> IO () +serverLoop opts = + withGRPC $ \grpc -> + withServer grpc (mkConfig opts) $ \server -> do + let rmsN = zip (optNormalHandlers opts) $ normalMethods server + let rmsCS = zip (optClientStreamHandlers opts) $ cstreamingMethods server + let rmsSS = zip (optServerStreamHandlers opts) $ sstreamingMethods server + let rmsB = zip (optBiDiStreamHandlers opts) $ bidiStreamingMethods server + --TODO: Perhaps assert that no methods disappeared after registration. + let loop :: forall a. (Handler a, RegisteredMethod a) -> IO () + loop = handleLoop server + asyncsN <- mapM async $ map loop rmsN + asyncsCS <- mapM async $ map loop rmsCS + asyncsSS <- mapM async $ map loop rmsSS + asyncsB <- mapM async $ map loop rmsB + asyncUnk <- async $ loopWError $ unknownHandler server + waitAnyCancel $ asyncUnk : asyncsN ++ asyncsCS ++ asyncsSS ++ asyncsB + return () + where + mkConfig ServerOptions{..} = + ServerConfig + { host = "localhost" + , port = optServerPort + , methodsToRegisterNormal = map handlerMethodName optNormalHandlers + , methodsToRegisterClientStreaming = + map handlerMethodName optClientStreamHandlers + , methodsToRegisterServerStreaming = + map handlerMethodName optServerStreamHandlers + , methodsToRegisterBiDiStreaming = + map handlerMethodName optBiDiStreamHandlers + , serverArgs = + ([CompressionAlgArg GrpcCompressDeflate | optUseCompression] + ++ + [UserAgentPrefix optUserAgentPrefix + , UserAgentSuffix optUserAgentSuffix]) + } + unknownHandler s = + --TODO: is this working? + U.serverHandleNormalCall s mempty $ \call _ -> do + logShow $ "Requested unknown endpoint: " ++ show (U.callMethod call) + return ("", mempty, StatusNotFound, + StatusDetails "Unknown method") diff --git a/src/Network/GRPC/LowLevel.hs b/src/Network/GRPC/LowLevel.hs index bc69028..1a43277 100644 --- a/src/Network/GRPC/LowLevel.hs +++ b/src/Network/GRPC/LowLevel.hs @@ -29,6 +29,7 @@ GRPC -- * Configuration options , Arg(..) , CompressionAlgorithm(..) +, Port -- * Server , ServerConfig(..) @@ -37,12 +38,16 @@ GRPC , ServerCall(optionalPayload, requestMetadataRecv) , withServer , serverHandleNormalCall +, ServerHandler , withServerCall , serverCallCancel , serverCallIsExpired , serverReader -- for client streaming +, ServerReaderHandler , serverWriter -- for server streaming +, ServerWriterHandler , serverRW -- for bidirectional streaming +, ServerRWHandler -- * Client , ClientConfig(..) @@ -51,7 +56,10 @@ GRPC , ConnectivityState(..) , clientConnectivity , withClient -, clientRegisterMethod +, clientRegisterMethodNormal +, clientRegisterMethodClientStreaming +, clientRegisterMethodServerStreaming +, clientRegisterMethodBiDiStreaming , clientRequest , clientReader -- for server streaming , clientWriter -- for client streaming @@ -64,6 +72,11 @@ GRPC , Op(..) , OpRecvResult(..) +-- * Streaming utilities +, Streaming +, StreamSend +, StreamRecv + ) where import Network.GRPC.LowLevel.GRPC diff --git a/src/Network/GRPC/LowLevel/Call.hs b/src/Network/GRPC/LowLevel/Call.hs index 91e79ad..d850522 100644 --- a/src/Network/GRPC/LowLevel/Call.hs +++ b/src/Network/GRPC/LowLevel/Call.hs @@ -1,26 +1,35 @@ {-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} -- | This module defines data structures and operations pertaining to registered -- calls; for unregistered call support, see -- `Network.GRPC.LowLevel.Call.Unregistered`. module Network.GRPC.LowLevel.Call where -import Data.ByteString (ByteString) -import Data.String (IsString) -#ifdef DEBUG -import Foreign.Storable (peek) -#endif +import Control.Monad.Managed (Managed, managed) +import Control.Exception (bracket) +import Data.ByteString (ByteString) +import Data.List (intersperse) +import Data.String (IsString) +import Foreign.Marshal.Alloc (free, malloc) +import Foreign.Ptr (Ptr, nullPtr) +import Foreign.Storable (Storable, peek) +import Network.GRPC.LowLevel.CompletionQueue.Internal +import Network.GRPC.LowLevel.GRPC (MetadataMap, + grpcDebug) +import qualified Network.GRPC.Unsafe as C +import qualified Network.GRPC.Unsafe.ByteBuffer as C +import qualified Network.GRPC.Unsafe.Op as C import System.Clock -import qualified Network.GRPC.Unsafe as C -import qualified Network.GRPC.Unsafe.Op as C - -import Network.GRPC.LowLevel.GRPC (MetadataMap, grpcDebug) - -- | Models the four types of RPC call supported by gRPC (and correspond to -- DataKinds phantom types on RegisteredMethods). data GRPCMethodType @@ -30,6 +39,23 @@ data GRPCMethodType | BiDiStreaming deriving (Show, Eq, Ord, Enum) +type family MethodPayload a where + MethodPayload 'Normal = ByteString + MethodPayload 'ClientStreaming = () + MethodPayload 'ServerStreaming = ByteString + MethodPayload 'BiDiStreaming = () + +--TODO: try replacing this class with a plain old function so we don't have the +-- Payloadable constraint everywhere. + +payload :: RegisteredMethod mt -> Ptr C.ByteBuffer -> IO (MethodPayload mt) +payload (RegisteredMethodNormal _ _ _) p = + peek p >>= C.copyByteBufferToByteString +payload (RegisteredMethodClientStreaming _ _ _) _ = return () +payload (RegisteredMethodServerStreaming _ _ _) p = + peek p >>= C.copyByteBufferToByteString +payload (RegisteredMethodBiDiStreaming _ _ _) _ = return () + newtype MethodName = MethodName {unMethodName :: String} deriving (Show, Eq, IsString) @@ -53,32 +79,82 @@ endpoint (Host h) (Port p) = Endpoint (h ++ ":" ++ show p) -- Contains state for identifying that method in the underlying gRPC -- library. Note that we use a DataKind-ed phantom type to help constrain use of -- different kinds of registered methods. -data RegisteredMethod (mt :: GRPCMethodType) = RegisteredMethod - { methodType :: GRPCMethodType - , methodName :: MethodName - , methodEndpoint :: Endpoint - , methodHandle :: C.CallHandle - } +data RegisteredMethod (mt :: GRPCMethodType) where + RegisteredMethodNormal :: MethodName + -> Endpoint + -> C.CallHandle + -> RegisteredMethod 'Normal + RegisteredMethodClientStreaming :: MethodName + -> Endpoint + -> C.CallHandle + -> RegisteredMethod 'ClientStreaming + RegisteredMethodServerStreaming :: MethodName + -> Endpoint + -> C.CallHandle + -> RegisteredMethod 'ServerStreaming + RegisteredMethodBiDiStreaming :: MethodName + -> Endpoint + -> C.CallHandle + -> RegisteredMethod 'BiDiStreaming + +instance Show (RegisteredMethod a) where + show (RegisteredMethodNormal x y z) = + "RegisteredMethodNormal " + ++ concat (intersperse " " [show x, show y, show z]) + show (RegisteredMethodClientStreaming x y z) = + "RegisteredMethodClientStreaming " + ++ concat (intersperse " " [show x, show y, show z]) + show (RegisteredMethodServerStreaming x y z) = + "RegisteredMethodServerStreaming " + ++ concat (intersperse " " [show x, show y, show z]) + show (RegisteredMethodBiDiStreaming x y z) = + "RegisteredMethodBiDiStreaming " + ++ concat (intersperse " " [show x, show y, show z]) + +methodName :: RegisteredMethod mt -> MethodName +methodName (RegisteredMethodNormal x _ _) = x +methodName (RegisteredMethodClientStreaming x _ _) = x +methodName (RegisteredMethodServerStreaming x _ _) = x +methodName (RegisteredMethodBiDiStreaming x _ _) = x + +methodEndpoint :: RegisteredMethod mt -> Endpoint +methodEndpoint (RegisteredMethodNormal _ x _) = x +methodEndpoint (RegisteredMethodClientStreaming _ x _) = x +methodEndpoint (RegisteredMethodServerStreaming _ x _) = x +methodEndpoint (RegisteredMethodBiDiStreaming _ x _) = x + +methodHandle :: RegisteredMethod mt -> C.CallHandle +methodHandle (RegisteredMethodNormal _ _ x) = x +methodHandle (RegisteredMethodClientStreaming _ _ x) = x +methodHandle (RegisteredMethodServerStreaming _ _ x) = x +methodHandle (RegisteredMethodBiDiStreaming _ _ x) = x + +methodType :: RegisteredMethod mt -> GRPCMethodType +methodType (RegisteredMethodNormal _ _ _) = Normal +methodType (RegisteredMethodClientStreaming _ _ _) = ClientStreaming +methodType (RegisteredMethodServerStreaming _ _ _) = ServerStreaming +methodType (RegisteredMethodBiDiStreaming _ _ _) = BiDiStreaming -- | Represents one GRPC call (i.e. request) on the client. -- This is used to associate send/receive 'Op's with a request. -data ClientCall = ClientCall { unClientCall :: C.Call } +data ClientCall = ClientCall { unsafeCC :: C.Call } clientCallCancel :: ClientCall -> IO () -clientCallCancel cc = C.grpcCallCancel (unClientCall cc) C.reserved +clientCallCancel cc = C.grpcCallCancel (unsafeCC cc) C.reserved -- | Represents one registered GRPC call on the server. Contains pointers to all -- the C state needed to respond to a registered call. -data ServerCall = ServerCall - { unServerCall :: C.Call, - requestMetadataRecv :: MetadataMap, - optionalPayload :: Maybe ByteString, - callDeadline :: TimeSpec - } +data ServerCall a = ServerCall + { unsafeSC :: C.Call + , callCQ :: CompletionQueue + , requestMetadataRecv :: MetadataMap + , optionalPayload :: a + , callDeadline :: TimeSpec + } deriving (Functor, Show) -serverCallCancel :: ServerCall -> C.StatusCode -> String -> IO () +serverCallCancel :: ServerCall a -> C.StatusCode -> String -> IO () serverCallCancel sc code reason = - C.grpcCallCancelWithStatus (unServerCall sc) code reason C.reserved + C.grpcCallCancelWithStatus (unsafeSC sc) code reason C.reserved -- | NB: For now, we've assumed that the method type is all the info we need to -- decide the server payload handling method. @@ -88,7 +164,17 @@ payloadHandling ClientStreaming = C.SrmPayloadNone payloadHandling ServerStreaming = C.SrmPayloadReadInitialByteBuffer payloadHandling BiDiStreaming = C.SrmPayloadNone -serverCallIsExpired :: ServerCall -> IO Bool +-- | Optionally allocate a managed byte buffer for a payload, depending on the +-- given method type. If no payload is needed, the returned pointer is null +mgdPayload :: GRPCMethodType -> Managed (Ptr C.ByteBuffer) +mgdPayload mt + | payloadHandling mt == C.SrmPayloadNone = return nullPtr + | otherwise = managed C.withByteBufferPtr + +mgdPtr :: forall a. Storable a => Managed (Ptr a) +mgdPtr = managed (bracket malloc free) + +serverCallIsExpired :: ServerCall a -> IO Bool serverCallIsExpired sc = do currTime <- getTime Monotonic return $ currTime > (callDeadline sc) @@ -102,27 +188,27 @@ debugClientCall (ClientCall (C.Call ptr)) = debugClientCall = const $ return () #endif -debugServerCall :: ServerCall -> IO () +debugServerCall :: ServerCall a -> IO () #ifdef DEBUG -debugServerCall call@(ServerCall (C.Call ptr) _ _ _) = do - grpcDebug $ "debugServerCall(R): server call: " ++ (show ptr) - grpcDebug $ "debugServerCall(R): metadata ptr: " - ++ show (requestMetadataRecv call) - grpcDebug $ "debugServerCall(R): payload ptr: " ++ show (optionalPayload call) - grpcDebug $ "debugServerCall(R): deadline ptr: " ++ show (callDeadline call) +debugServerCall sc@(ServerCall (C.Call ptr) _ _ _ _) = do + let dbug = grpcDebug . ("debugServerCall(R): " ++) + dbug $ "server call: " ++ show ptr + dbug $ "callCQ: " ++ show (callCQ sc) + dbug $ "metadata ptr: " ++ show (requestMetadataRecv sc) + dbug $ "deadline ptr: " ++ show (callDeadline sc) #else {-# INLINE debugServerCall #-} debugServerCall = const $ return () #endif destroyClientCall :: ClientCall -> IO () -destroyClientCall ClientCall{..} = do +destroyClientCall cc = do grpcDebug "Destroying client-side call object." - C.grpcCallDestroy unClientCall + C.grpcCallDestroy (unsafeCC cc) -destroyServerCall :: ServerCall -> IO () -destroyServerCall call@ServerCall{..} = do +destroyServerCall :: ServerCall a -> IO () +destroyServerCall sc@ServerCall{ unsafeSC = c } = do grpcDebug "destroyServerCall(R): entered." - debugServerCall call - grpcDebug $ "Destroying server-side call object: " ++ show unServerCall - C.grpcCallDestroy unServerCall + debugServerCall sc + grpcDebug $ "Destroying server-side call object: " ++ show c + C.grpcCallDestroy c diff --git a/src/Network/GRPC/LowLevel/Call/Unregistered.hs b/src/Network/GRPC/LowLevel/Call/Unregistered.hs index f37d22d..a125cf2 100644 --- a/src/Network/GRPC/LowLevel/Call/Unregistered.hs +++ b/src/Network/GRPC/LowLevel/Call/Unregistered.hs @@ -3,23 +3,25 @@ module Network.GRPC.LowLevel.Call.Unregistered where import Control.Monad -import Foreign.Marshal.Alloc (free) -import Foreign.Ptr (Ptr) +import Foreign.Marshal.Alloc (free) +import Foreign.Ptr (Ptr) #ifdef DEBUG import Foreign.Storable (peek) #endif -import System.Clock (TimeSpec) +import Network.GRPC.LowLevel.Call (Host (..), + MethodName (..)) +import Network.GRPC.LowLevel.CompletionQueue.Internal +import Network.GRPC.LowLevel.GRPC (MetadataMap, + grpcDebug) +import qualified Network.GRPC.Unsafe as C +import qualified Network.GRPC.Unsafe.Op as C +import System.Clock (TimeSpec) -import Network.GRPC.LowLevel.Call (Host (..), MethodName (..)) -import Network.GRPC.LowLevel.GRPC (MetadataMap, grpcDebug) -import qualified Network.GRPC.Unsafe as C -import qualified Network.GRPC.Unsafe.Op as C - --- | Represents one unregistered GRPC call on the server. --- Contains pointers to all the C state needed to respond to an unregistered --- call. +-- | Represents one unregistered GRPC call on the server. Contains pointers to +-- all the C state needed to respond to an unregistered call. data ServerCall = ServerCall - { unServerCall :: C.Call + { unsafeSC :: C.Call + , callCQ :: CompletionQueue , requestMetadataRecv :: MetadataMap , parentPtr :: Maybe (Ptr C.Call) , callDeadline :: TimeSpec @@ -29,18 +31,25 @@ data ServerCall = ServerCall serverCallCancel :: ServerCall -> C.StatusCode -> String -> IO () serverCallCancel sc code reason = - C.grpcCallCancelWithStatus (unServerCall sc) code reason C.reserved + C.grpcCallCancelWithStatus (unsafeSC sc) code reason C.reserved debugServerCall :: ServerCall -> IO () #ifdef DEBUG debugServerCall ServerCall{..} = do - let C.Call ptr = unServerCall - grpcDebug $ "debugServerCall(U): server call: " ++ show ptr - grpcDebug $ "debugServerCall(U): metadata: " - ++ show requestMetadataRecv - grpcDebug $ "debugServerCall(U): deadline: " ++ show callDeadline - grpcDebug $ "debugServerCall(U): method: " ++ show callMethod - grpcDebug $ "debugServerCall(U): host: " ++ show callHost + let C.Call ptr = unsafeSC + dbug = grpcDebug . ("debugServerCall(U): " ++) + + dbug $ "server call: " ++ show ptr + dbug $ "metadata: " ++ show requestMetadataRecv + + forM_ parentPtr $ \parentPtr' -> do + dbug $ "parent ptr: " ++ show parentPtr' + C.Call parent <- peek parentPtr' + dbug $ "parent: " ++ show parent + + dbug $ "deadline: " ++ show callDeadline + dbug $ "method: " ++ show callMethod + dbug $ "host: " ++ show callHost #else {-# INLINE debugServerCall #-} debugServerCall = const $ return () @@ -50,7 +59,7 @@ destroyServerCall :: ServerCall -> IO () destroyServerCall call@ServerCall{..} = do grpcDebug "destroyServerCall(U): entered." debugServerCall call - grpcDebug $ "Destroying server-side call object: " ++ show unServerCall - C.grpcCallDestroy unServerCall + grpcDebug $ "Destroying server-side call object: " ++ show unsafeSC + C.grpcCallDestroy unsafeSC grpcDebug $ "freeing parentPtr: " ++ show parentPtr forM_ parentPtr free diff --git a/src/Network/GRPC/LowLevel/Client.hs b/src/Network/GRPC/LowLevel/Client.hs index 975200b..9cadd96 100644 --- a/src/Network/GRPC/LowLevel/Client.hs +++ b/src/Network/GRPC/LowLevel/Client.hs @@ -10,10 +10,8 @@ -- `Network.GRPC.LowLevel.Client.Unregistered`. module Network.GRPC.LowLevel.Client where -import Control.Arrow import Control.Exception (bracket, finally) import Control.Monad -import Control.Monad.Trans.Class (MonadTrans(lift)) import Control.Monad.Trans.Except import Data.ByteString (ByteString) import Network.GRPC.LowLevel.Call @@ -26,9 +24,6 @@ import qualified Network.GRPC.Unsafe.Constants as C import qualified Network.GRPC.Unsafe.Op as C import qualified Network.GRPC.Unsafe.Time as C -import qualified Pipes as P -import qualified Pipes.Core as P - -- | Represents the context needed to perform client-side gRPC operations. data Client = Client {clientChannel :: C.Channel, clientCQ :: CompletionQueue, @@ -76,17 +71,54 @@ clientConnectivity :: Client -> IO C.ConnectivityState clientConnectivity Client{..} = C.grpcChannelCheckConnectivityState clientChannel False +--TODO: We should probably also register client methods on startup. + -- | Register a method on the client so that we can call it with -- 'clientRequest'. clientRegisterMethod :: Client -> MethodName - -> GRPCMethodType - -> IO (RegisteredMethod mt) -clientRegisterMethod Client{..} meth mty = do + -> IO (C.CallHandle) +clientRegisterMethod Client{..} meth = do let e = clientEndpoint clientConfig - RegisteredMethod mty meth e <$> - C.grpcChannelRegisterCall clientChannel - (unMethodName meth) (unEndpoint e) C.reserved + C.grpcChannelRegisterCall clientChannel + (unMethodName meth) + (unEndpoint e) + C.reserved + + +clientRegisterMethodNormal :: Client + -> MethodName + -> IO (RegisteredMethod 'Normal) +clientRegisterMethodNormal c meth = do + let e = clientEndpoint (clientConfig c) + h <- clientRegisterMethod c meth + return $ RegisteredMethodNormal meth e h + + +clientRegisterMethodClientStreaming :: Client + -> MethodName + -> IO (RegisteredMethod 'ClientStreaming) +clientRegisterMethodClientStreaming c meth = do + let e = clientEndpoint (clientConfig c) + h <- clientRegisterMethod c meth + return $ RegisteredMethodClientStreaming meth e h + +clientRegisterMethodServerStreaming :: Client + -> MethodName + -> IO (RegisteredMethod 'ServerStreaming) +clientRegisterMethodServerStreaming c meth = do + let e = clientEndpoint (clientConfig c) + h <- clientRegisterMethod c meth + return $ RegisteredMethodServerStreaming meth e h + + +clientRegisterMethodBiDiStreaming :: Client + -> MethodName + -> IO (RegisteredMethod 'BiDiStreaming) +clientRegisterMethodBiDiStreaming c meth = do + let e = clientEndpoint (clientConfig c) + h <- clientRegisterMethod c meth + return $ RegisteredMethodBiDiStreaming meth e h -- | Create a new call on the client for a registered method. -- Returns 'Left' if the CQ is shutting down or if the job to create a call @@ -103,13 +135,13 @@ clientCreateCall c rm ts = clientCreateCallParent c rm ts Nothing clientCreateCallParent :: Client -> RegisteredMethod mt -> TimeoutSeconds - -> Maybe ServerCall + -> Maybe (ServerCall a) -- ^ Optional parent call for cascading cancellation. -> IO (Either GRPCIOError ClientCall) -clientCreateCallParent Client{..} RegisteredMethod{..} timeout parent = do +clientCreateCallParent Client{..} rm timeout parent = do C.withDeadlineSeconds timeout $ \deadline -> do channelCreateCall clientChannel parent C.propagateDefaults - clientCQ methodHandle deadline + clientCQ (methodHandle rm) deadline -- | Handles safe creation and cleanup of a client call withClientCall :: Client @@ -117,27 +149,26 @@ withClientCall :: Client -> TimeoutSeconds -> (ClientCall -> IO (Either GRPCIOError a)) -> IO (Either GRPCIOError a) -withClientCall client regmethod timeout f = - withClientCallParent client regmethod timeout Nothing f +withClientCall cl rm tm = withClientCallParent cl rm tm Nothing -- | Handles safe creation and cleanup of a client call, with an optional parent -- call parameter. This allows for cancellation to cascade from the parent -- `ServerCall` to the created `ClientCall`. Obviously, this is only useful if -- the given gRPC client is also a server. withClientCallParent :: Client - -> RegisteredMethod mt - -> TimeoutSeconds - -> (Maybe ServerCall) - -- ^ Optional parent call for cascading cancellation. - -> (ClientCall -> IO (Either GRPCIOError a)) - -> IO (Either GRPCIOError a) -withClientCallParent client regmethod timeout parent f = do - createResult <- clientCreateCallParent client regmethod timeout parent - case createResult of - Left x -> return $ Left x - Right call -> f call `finally` logDestroy call - where logDestroy c = grpcDebug "withClientCall(R): destroying." - >> destroyClientCall c + -> RegisteredMethod mt + -> TimeoutSeconds + -> Maybe (ServerCall a) + -- ^ Optional parent call for cascading cancellation + -> (ClientCall -> IO (Either GRPCIOError a)) + -> IO (Either GRPCIOError a) +withClientCallParent cl rm tm parent f = + clientCreateCallParent cl rm tm parent >>= \case + Left e -> return (Left e) + Right c -> f c `finally` do + debugClientCall c + grpcDebug "withClientCall(R): destroying." + destroyClientCall c data NormalRequestResult = NormalRequestResult { rspBody :: ByteString @@ -166,7 +197,7 @@ compileNormalRequestResults x = -- clientReader (client side of server streaming mode) -- | First parameter is initial server metadata. -type ClientReaderHandler = MetadataMap -> StreamRecv -> Streaming () +type ClientReaderHandler = MetadataMap -> StreamRecv ByteString -> Streaming () clientReader :: Client -> RegisteredMethod 'ServerStreaming @@ -178,8 +209,7 @@ clientReader :: Client clientReader cl@Client{ clientCQ = cq } rm tm body initMeta f = withClientCall cl rm tm go where - go cc@(unClientCall -> c) = runExceptT $ do - lift (debugClientCall cc) + go (unsafeCC -> c) = runExceptT $ do runOps' c cq [ OpSendInitialMetadata initMeta , OpSendMessage body , OpSendCloseFromClient @@ -191,7 +221,7 @@ clientReader cl@Client{ clientCQ = cq } rm tm body initMeta f = -------------------------------------------------------------------------------- -- clientWriter (client side of client streaming mode) -type ClientWriterHandler = StreamSend -> Streaming () +type ClientWriterHandler = StreamSend ByteString -> Streaming () type ClientWriterResult = (Maybe ByteString, MetadataMap, MetadataMap, C.StatusCode, StatusDetails) @@ -209,9 +239,8 @@ clientWriterCmn :: Client -- ^ The active client -> ClientWriterHandler -> ClientCall -- ^ The active client call -> IO (Either GRPCIOError ClientWriterResult) -clientWriterCmn (clientCQ -> cq) initMeta f cc@(unClientCall -> c) = +clientWriterCmn (clientCQ -> cq) initMeta f (unsafeCC -> c) = runExceptT $ do - lift (debugClientCall cc) sendInitialMetadata c cq initMeta runStreamingProxy "clientWriterCmn" c cq (f streamSend) sendSingle c cq OpSendCloseFromClient @@ -231,7 +260,10 @@ pattern CWRFinal mmsg initMD trailMD st ds -- clientRW (client side of bidirectional streaming mode) -- | First parameter is initial server metadata. -type ClientRWHandler = MetadataMap -> StreamRecv -> StreamSend -> Streaming () +type ClientRWHandler = MetadataMap + -> StreamRecv ByteString + -> StreamSend ByteString + -> Streaming () -- | For bidirectional-streaming registered requests clientRW :: Client @@ -241,16 +273,15 @@ clientRW :: Client -- ^ request metadata -> ClientRWHandler -> IO (Either GRPCIOError (MetadataMap, C.StatusCode, StatusDetails)) -clientRW c@Client{ clientCQ = cq } rm tm initMeta f = - withClientCall c rm tm go +clientRW cl@(clientCQ -> cq) rm tm initMeta f = + withClientCall cl rm tm go where - go cc@(unClientCall -> call) = runExceptT $ do - lift (debugClientCall cc) - sendInitialMetadata call cq initMeta - srvMeta <- recvInitialMetadata call cq - runStreamingProxy "clientRW" call cq (f srvMeta streamRecv streamSend) - runOps' call cq [OpSendCloseFromClient] -- WritesDone() - recvStatusOnClient call cq -- Finish() + go (unsafeCC -> c) = runExceptT $ do + sendInitialMetadata c cq initMeta + srvMeta <- recvInitialMetadata c cq + runStreamingProxy "clientRW" c cq (f srvMeta streamRecv streamSend) + runOps' c cq [OpSendCloseFromClient] -- WritesDone() + recvStatusOnClient c cq -- Finish() -------------------------------------------------------------------------------- -- clientRequest (client side of normal request/response) @@ -265,15 +296,13 @@ clientRequest :: Client -> MetadataMap -- ^ Metadata to send with the request -> IO (Either GRPCIOError NormalRequestResult) -clientRequest c@Client{ clientCQ = cq } rm tm body initMeta = - withClientCall c rm tm (fmap join . go) +clientRequest cl@(clientCQ -> cq) rm tm body initMeta = + withClientCall cl rm tm (fmap join . go) where - go cc@(unClientCall -> call) = do - grpcDebug "clientRequest(R): created call." - debugClientCall cc + go (unsafeCC -> c) = -- NB: the send and receive operations below *must* be in separate -- batches, or the client hangs when the server can't be reached. - runOps call cq + runOps c cq [ OpSendInitialMetadata initMeta , OpSendMessage body , OpSendCloseFromClient @@ -283,7 +312,7 @@ clientRequest c@Client{ clientCQ = cq } rm tm body initMeta = grpcDebug "clientRequest(R) : batch error sending." return $ Left x Right rs -> - runOps call cq + runOps c cq [ OpRecvInitialMetadata , OpRecvMessage , OpRecvStatusOnClient diff --git a/src/Network/GRPC/LowLevel/Client/Unregistered.hs b/src/Network/GRPC/LowLevel/Client/Unregistered.hs index 0fe6da9..8800952 100644 --- a/src/Network/GRPC/LowLevel/Client/Unregistered.hs +++ b/src/Network/GRPC/LowLevel/Client/Unregistered.hs @@ -1,4 +1,5 @@ {-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ViewPatterns #-} module Network.GRPC.LowLevel.Client.Unregistered where @@ -59,15 +60,17 @@ clientRequest :: Client -> MetadataMap -- ^ Request metadata. -> IO (Either GRPCIOError NormalRequestResult) -clientRequest client@Client{..} meth timeLimit body meta = - fmap join $ withClientCall client meth timeLimit $ \call -> do - results <- runOps (unClientCall call) clientCQ - [ OpSendInitialMetadata meta - , OpSendMessage body - , OpSendCloseFromClient - , OpRecvInitialMetadata - , OpRecvMessage - , OpRecvStatusOnClient - ] - grpcDebug "clientRequest(U): ops ran." - return $ right compileNormalRequestResults results +clientRequest cl@(clientCQ -> cq) meth tm body initMeta = + join <$> withClientCall cl meth tm go + where + go (unsafeCC -> c) = do + results <- runOps c cq + [ OpSendInitialMetadata initMeta + , OpSendMessage body + , OpSendCloseFromClient + , OpRecvInitialMetadata + , OpRecvMessage + , OpRecvStatusOnClient + ] + grpcDebug "clientRequest(U): ops ran." + return $ right compileNormalRequestResults results diff --git a/src/Network/GRPC/LowLevel/CompletionQueue.hs b/src/Network/GRPC/LowLevel/CompletionQueue.hs index 9b436df..fbdc2b0 100644 --- a/src/Network/GRPC/LowLevel/CompletionQueue.hs +++ b/src/Network/GRPC/LowLevel/CompletionQueue.hs @@ -16,6 +16,7 @@ {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} +{-# LANGUAGE ViewPatterns #-} module Network.GRPC.LowLevel.CompletionQueue ( CompletionQueue @@ -40,7 +41,6 @@ import Control.Concurrent.STM.TVar (newTVarIO, readTVar, writeTVar) import Control.Exception (bracket) -import Control.Monad (liftM2) import Control.Monad.Managed import Control.Monad.Trans.Class (MonadTrans (lift)) import Control.Monad.Trans.Except @@ -120,7 +120,7 @@ shutdownCompletionQueue CompletionQueue{..} = do C.OpComplete -> drainLoop channelCreateCall :: C.Channel - -> (Maybe ServerCall) + -> Maybe (ServerCall a) -> C.PropagationMask -> CompletionQueue -> C.CallHandle @@ -129,7 +129,7 @@ channelCreateCall :: C.Channel channelCreateCall chan parent mask cq@CompletionQueue{..} handle deadline = withPermission Push cq $ do - let parentPtr = maybe (C.Call nullPtr) unServerCall parent + let parentPtr = maybe (C.Call nullPtr) unsafeSC parent grpcDebug $ "channelCreateCall: call with " ++ concat (intersperse " " [show chan, show parentPtr, show mask, @@ -140,56 +140,50 @@ channelCreateCall return $ Right $ ClientCall call -- | Create the call object to handle a registered call. -serverRequestCall :: C.Server - -> CompletionQueue - -> RegisteredMethod mt - -> IO (Either GRPCIOError ServerCall) -serverRequestCall s cq@CompletionQueue{.. } rm = +serverRequestCall :: RegisteredMethod mt + -> C.Server + -> CompletionQueue -- ^ server CQ + -> CompletionQueue -- ^ call CQ + -> IO (Either GRPCIOError (ServerCall (MethodPayload mt))) +serverRequestCall rm s scq ccq = -- NB: The method type dictates whether or not a payload is present, according -- to the payloadHandling function. We do not allocate a buffer for the -- payload when it is not present. - withPermission Push cq . with allocs $ \(dead, call, pay, meta) -> do - dbug "pre-pluck block" - withPermission Pluck cq $ do + withPermission Push scq . with allocs $ \(dead, call, pay, meta) -> + withPermission Pluck scq $ do md <- peek meta - tag <- newTag cq + tag <- newTag scq dbug $ "got pluck permission, registering call for tag=" ++ show tag - ce <- C.grpcServerRequestRegisteredCall s (methodHandle rm) call dead md pay unsafeCQ unsafeCQ tag + ce <- C.grpcServerRequestRegisteredCall s (methodHandle rm) call dead md + pay (unsafeCQ ccq) (unsafeCQ scq) tag runExceptT $ case ce of C.CallOk -> do ExceptT $ do - r <- pluck' cq tag Nothing + r <- pluck' scq tag Nothing dbug $ "pluck' finished:" ++ show r return r lift $ ServerCall <$> peek call + <*> return ccq <*> C.getAllMetadataArray md - <*> (if havePay then toBS pay else return Nothing) + <*> payload rm pay <*> convertDeadline dead _ -> do lift $ dbug $ "Throwing callError: " ++ show ce throwE (GRPCIOCallError ce) where - allocs = (,,,) <$> ptr <*> ptr <*> pay <*> md - where - md = managed C.withMetadataArrayPtr - pay = if havePay then managed C.withByteBufferPtr else return nullPtr - ptr :: forall a. Storable a => Managed (Ptr a) - ptr = managed (bracket malloc free) - dbug = grpcDebug . ("serverRequestCall(R): " ++) - havePay = payloadHandling (methodType rm) /= C.SrmPayloadNone - toBS p = peek p >>= \bb@(C.ByteBuffer rawPtr) -> - if | rawPtr == nullPtr -> return Nothing - | otherwise -> Just <$> C.copyByteBufferToByteString bb - convertDeadline deadline = do - deadline' <- C.timeSpec <$> peek deadline - -- On OS X, gRPC gives us a deadline that is just a delta, so we convert - -- it to an actual deadline. - if os == "darwin" - then do now <- getTime Monotonic - return $ now + deadline' - else return deadline' + allocs = (,,,) + <$> mgdPtr + <*> mgdPtr + <*> mgdPayload (methodType rm) + <*> managed C.withMetadataArrayPtr + dbug = grpcDebug . ("serverRequestCall(R): " ++) + -- On OS X, gRPC gives us a deadline that is just a delta, so we convert + -- it to an actual deadline. + convertDeadline (fmap C.timeSpec . peek -> d) + | os == "darwin" = (+) <$> d <*> getTime Monotonic + | otherwise = d -- | Register the server's completion queue. Must be done before the server is -- started. diff --git a/src/Network/GRPC/LowLevel/CompletionQueue/Internal.hs b/src/Network/GRPC/LowLevel/CompletionQueue/Internal.hs index 6aa29a6..e0215ff 100644 --- a/src/Network/GRPC/LowLevel/CompletionQueue/Internal.hs +++ b/src/Network/GRPC/LowLevel/CompletionQueue/Internal.hs @@ -5,7 +5,6 @@ module Network.GRPC.LowLevel.CompletionQueue.Internal where import Control.Concurrent.STM (atomically, retry) import Control.Concurrent.STM.TVar (TVar, modifyTVar', readTVar, writeTVar) -import Control.Monad.IO.Class import Control.Exception (bracket) import Control.Monad import Data.IORef (IORef, atomicModifyIORef') @@ -62,6 +61,8 @@ data CompletionQueue = CompletionQueue {unsafeCQ :: C.CompletionQueue, -- items pushed onto the queue. } +instance Show CompletionQueue where show = show . unsafeCQ + type TimeoutSeconds = Int data CQOpType = Push | Pluck deriving (Show, Eq, Enum) diff --git a/src/Network/GRPC/LowLevel/CompletionQueue/Unregistered.hs b/src/Network/GRPC/LowLevel/CompletionQueue/Unregistered.hs index ce4818f..f217ba5 100644 --- a/src/Network/GRPC/LowLevel/CompletionQueue/Unregistered.hs +++ b/src/Network/GRPC/LowLevel/CompletionQueue/Unregistered.hs @@ -1,10 +1,20 @@ -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE ViewPatterns #-} module Network.GRPC.LowLevel.CompletionQueue.Unregistered where import Control.Exception (bracket) +import Control.Monad.Managed +import Control.Monad.Trans.Class (MonadTrans (lift)) +import Control.Monad.Trans.Except import Foreign.Marshal.Alloc (free, malloc) -import Foreign.Storable (peek) +import Foreign.Ptr (Ptr) +import Foreign.Storable (Storable, peek) import Network.GRPC.LowLevel.Call import qualified Network.GRPC.LowLevel.Call.Unregistered as U import Network.GRPC.LowLevel.CompletionQueue.Internal @@ -30,47 +40,37 @@ channelCreateCall chan parent mask cq@CompletionQueue{..} meth endpt deadline = serverRequestCall :: C.Server - -> CompletionQueue + -> CompletionQueue -- ^ server CQ / notification CQ + -> CompletionQueue -- ^ call CQ -> IO (Either GRPCIOError U.ServerCall) -serverRequestCall server cq@CompletionQueue{..} = - withPermission Push cq $ - bracket malloc free $ \callPtr -> - C.withMetadataArrayPtr $ \metadataArrayPtr -> - C.withCallDetails $ \callDetails -> - withPermission Pluck cq $ do - grpcDebug $ "serverRequestCall: callPtr is " ++ show callPtr - metadataArray <- peek metadataArrayPtr - tag <- newTag cq - callError <- C.grpcServerRequestCall server callPtr callDetails - metadataArray unsafeCQ unsafeCQ tag - grpcDebug $ "serverRequestCall: callError was " ++ show callError - if callError /= C.CallOk - then do grpcDebug "serverRequestCall: got call error; cleaning up." - return $ Left $ GRPCIOCallError callError - else do pluckResult <- pluck cq tag Nothing - grpcDebug $ "serverRequestCall: pluckResult was " - ++ show pluckResult - case pluckResult of - Left x -> do - grpcDebug "serverRequestCall: pluck error." - return $ Left x - Right () -> do - rawCall <- peek callPtr - metadata <- C.getAllMetadataArray metadataArray - deadline <- getDeadline callDetails - method <- getMethod callDetails - host <- getHost callDetails - let call = U.ServerCall rawCall - metadata - Nothing - deadline - method - host - return $ Right call - - where getDeadline callDetails = do - C.timeSpec <$> C.callDetailsGetDeadline callDetails - getMethod callDetails = - MethodName <$> C.callDetailsGetMethod callDetails - getHost callDetails = - Host <$> C.callDetailsGetHost callDetails +serverRequestCall s scq ccq = + withPermission Push scq . with allocs $ \(call, meta, cd) -> + withPermission Pluck scq $ do + md <- peek meta + tag <- newTag scq + dbug $ "got pluck permission, registering call for tag=" ++ show tag + ce <- C.grpcServerRequestCall s call cd md (unsafeCQ ccq) (unsafeCQ scq) tag + runExceptT $ case ce of + C.CallOk -> do + ExceptT $ do + r <- pluck' scq tag Nothing + dbug $ "pluck' finished: " ++ show r + return r + lift $ + U.ServerCall + <$> peek call + <*> return ccq + <*> C.getAllMetadataArray md + <*> return Nothing + <*> (C.timeSpec <$> C.callDetailsGetDeadline cd) + <*> (MethodName <$> C.callDetailsGetMethod cd) + <*> (Host <$> C.callDetailsGetHost cd) + _ -> do + lift $ dbug $ "Throwing callError: " ++ show ce + throwE $ GRPCIOCallError ce + where + allocs = (,,) + <$> mgdPtr + <*> managed C.withMetadataArrayPtr + <*> managed C.withCallDetails + dbug = grpcDebug . ("serverRequestCall(U): " ++) diff --git a/src/Network/GRPC/LowLevel/GRPC.hs b/src/Network/GRPC/LowLevel/GRPC.hs index 6fdb448..e3f17f5 100644 --- a/src/Network/GRPC/LowLevel/GRPC.hs +++ b/src/Network/GRPC/LowLevel/GRPC.hs @@ -11,7 +11,7 @@ import qualified Data.ByteString as B import qualified Data.Map as M import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe.Op as C - +import Proto3.Wire.Decode (ParseError) #ifdef DEBUG import GHC.Conc (myThreadId) @@ -47,7 +47,7 @@ data GRPCIOError = GRPCIOCallError C.CallError -- reasonable amount of time. | GRPCIOUnknownError | GRPCIOBadStatusCode C.StatusCode StatusDetails - + | GRPCIODecodeError ParseError | GRPCIOInternalMissingExpectedPayload | GRPCIOInternalUnexpectedRecv String -- debugging description deriving (Show, Eq) diff --git a/src/Network/GRPC/LowLevel/Op.hs b/src/Network/GRPC/LowLevel/Op.hs index cd16016..616743a 100644 --- a/src/Network/GRPC/LowLevel/Op.hs +++ b/src/Network/GRPC/LowLevel/Op.hs @@ -345,12 +345,12 @@ streamingProxy nm c cq = maybe recv send run = lift . runOps c cq urecv = GRPCIOInternalUnexpectedRecv . (nm ++) -type StreamRecv = Streaming (Either GRPCIOError (Maybe ByteString)) -streamRecv :: StreamRecv +type StreamRecv a = Streaming (Either GRPCIOError (Maybe a)) +streamRecv :: StreamRecv ByteString streamRecv = P.request Nothing -type StreamSend = ByteString -> Streaming (Either GRPCIOError ()) -streamSend :: StreamSend +type StreamSend a = a -> Streaming (Either GRPCIOError ()) +streamSend :: StreamSend ByteString streamSend = fmap void . P.request . Just pattern RecvMsgRslt mmsg <- Right [OpRecvMessageResult mmsg] diff --git a/src/Network/GRPC/LowLevel/Server.hs b/src/Network/GRPC/LowLevel/Server.hs index 6fce45a..ff53736 100644 --- a/src/Network/GRPC/LowLevel/Server.hs +++ b/src/Network/GRPC/LowLevel/Server.hs @@ -1,21 +1,21 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} -- | This module defines data structures and operations pertaining to registered -- servers using registered calls; for unregistered support, see -- `Network.GRPC.LowLevel.Server.Unregistered`. module Network.GRPC.LowLevel.Server where -import Control.Arrow import Control.Exception (bracket, finally) import Control.Monad -import Control.Monad.Trans.Class (MonadTrans (lift)) import Control.Monad.Trans.Except import Data.ByteString (ByteString) import Network.GRPC.LowLevel.Call @@ -25,18 +25,18 @@ import Network.GRPC.LowLevel.CompletionQueue (CompletionQueue, serverRegisterCompletionQueue, serverRequestCall, serverShutdownAndNotify, - shutdownCompletionQueue) + shutdownCompletionQueue, + withCompletionQueue) import Network.GRPC.LowLevel.GRPC import Network.GRPC.LowLevel.Op import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe.ChannelArgs as C import qualified Network.GRPC.Unsafe.Op as C -import qualified Pipes as P -import qualified Pipes.Core as P -- | Wraps various gRPC state needed to run a server. data Server = Server - { internalServer :: C.Server + { serverGRPC :: GRPC + , unsafeServer :: C.Server , serverCQ :: CompletionQueue , normalMethods :: [RegisteredMethod 'Normal] , sstreamingMethods :: [RegisteredMethod 'ServerStreaming] @@ -52,9 +52,11 @@ data ServerConfig = ServerConfig -- used. Setting to "localhost" works fine in tests. , port :: Port -- ^ Port on which to listen for requests. - , methodsToRegister :: [(MethodName, GRPCMethodType)] - -- ^ List of (method name, method type) tuples specifying all methods to - -- register. + , methodsToRegisterNormal :: [MethodName] + -- ^ List of normal (non-streaming) methods to register. + , methodsToRegisterClientStreaming :: [MethodName] + , methodsToRegisterServerStreaming :: [MethodName] + , methodsToRegisterBiDiStreaming :: [MethodName] , serverArgs :: [C.Arg] -- ^ Optional arguments for setting up the channel on the server. Supplying an -- empty list will cause the channel to use gRPC's default options. @@ -73,47 +75,47 @@ startServer grpc conf@ServerConfig{..} = when (actualPort /= unPort port) $ error $ "Unable to bind port: " ++ show port cq <- createCompletionQueue grpc + grpcDebug $ "startServer: server CQ: " ++ show cq serverRegisterCompletionQueue server cq -- Register methods according to their GRPCMethodType kind. It's a bit ugly -- to partition them this way, but we get very convenient phantom typing -- elsewhere by doing so. - (ns, ss, cs, bs) <- do - let f (ns, ss, cs, bs) (nm, mt) = do - let reg = serverRegisterMethod server nm e mt - case mt of - Normal -> ( , ss, cs, bs) . (:ns) <$> reg - ServerStreaming -> (ns, , cs, bs) . (:ss) <$> reg - ClientStreaming -> (ns, ss, , bs) . (:cs) <$> reg - BiDiStreaming -> (ns, ss, cs, ) . (:bs) <$> reg - foldM f ([],[],[],[]) methodsToRegister - + -- TODO: change order of args so we can eta reduce. + ns <- mapM (\nm -> serverRegisterMethodNormal server nm e) + methodsToRegisterNormal + ss <- mapM (\nm -> serverRegisterMethodServerStreaming server nm e) + methodsToRegisterServerStreaming + cs <- mapM (\nm -> serverRegisterMethodClientStreaming server nm e) + methodsToRegisterClientStreaming + bs <- mapM (\nm -> serverRegisterMethodBiDiStreaming server nm e) + methodsToRegisterBiDiStreaming C.grpcServerStart server - return $ Server server cq ns ss cs bs conf + return $ Server grpc server cq ns ss cs bs conf stopServer :: Server -> IO () -- TODO: Do method handles need to be freed? -stopServer Server{..} = do +stopServer Server{ unsafeServer = s, serverCQ = scq } = do grpcDebug "stopServer: calling shutdownNotify." shutdownNotify grpcDebug "stopServer: cancelling all calls." - C.grpcServerCancelAllCalls internalServer + C.grpcServerCancelAllCalls s grpcDebug "stopServer: call grpc_server_destroy." - C.grpcServerDestroy internalServer + C.grpcServerDestroy s grpcDebug "stopServer: shutting down CQ." shutdownCQ where shutdownCQ = do - shutdownResult <- shutdownCompletionQueue serverCQ + shutdownResult <- shutdownCompletionQueue scq case shutdownResult of Left _ -> do putStrLn "Warning: completion queue didn't shut down." putStrLn "Trying to stop server anyway." Right _ -> return () shutdownNotify = do let shutdownTag = C.tag 0 - serverShutdownAndNotify internalServer serverCQ shutdownTag + serverShutdownAndNotify s scq shutdownTag grpcDebug "called serverShutdownAndNotify; plucking." - shutdownEvent <- pluck serverCQ shutdownTag (Just 30) + shutdownEvent <- pluck scq shutdownTag (Just 30) grpcDebug $ "shutdownNotify: got shutdown event" ++ show shutdownEvent case shutdownEvent of -- This case occurs when we pluck but the queue is already in the @@ -126,11 +128,42 @@ stopServer Server{..} = do withServer :: GRPC -> ServerConfig -> (Server -> IO a) -> IO a withServer grpc cfg = bracket (startServer grpc cfg) stopServer +-- | Less precisely-typed registration function used in +-- 'serverRegisterMethodNormal', 'serverRegisterMethodServerStreaming', +-- 'serverRegisterMethodClientStreaming', and +-- 'serverRegisterMethodBiDiStreaming'. +serverRegisterMethod :: C.Server + -> MethodName + -> Endpoint + -> GRPCMethodType + -> IO (C.CallHandle) +serverRegisterMethod s nm e mty = + C.grpcServerRegisterMethod s + (unMethodName nm) + (unEndpoint e) + (payloadHandling mty) + +{- +TODO: Consolidate the register functions below. + +It seems like we'd need true dependent types to use only one + registration function. Ideally we'd want a type like + serverRegisterMethod :: C.Server + -> MethodName + -> Endpoint + -> (t :: GRPCMethodType) + -> IO (RegisteredMethod (Lifted t)) + +where `Lifted t` is the type in the t data kind that corresponds to the data +constructor t the function was given. + +-} + -- | Register a method on a server. The 'RegisteredMethod' type can then be used -- to wait for a request to arrive. Note: gRPC claims this must be called before -- the server is started, so we do it during startup according to the -- 'ServerConfig'. -serverRegisterMethod :: C.Server +serverRegisterMethodNormal :: C.Server -> MethodName -- ^ method name, e.g. "/foo" -> Endpoint @@ -139,42 +172,86 @@ serverRegisterMethod :: C.Server -- parameters to start a server in the first place. It -- doesn't seem to have any effect, even if it's filled -- with nonsense. - -> GRPCMethodType - -- ^ Type of method this will be. In the future, this will - -- be used to switch to the correct handling logic. - -> IO (RegisteredMethod mt) -serverRegisterMethod internalServer meth e mty = - RegisteredMethod mty meth e <$> do - h <- C.grpcServerRegisterMethod internalServer - (unMethodName meth) (unEndpoint e) (payloadHandling mty) - grpcDebug $ "registered method handle: " ++ show h ++ " of type " ++ show mty - return h + -> IO (RegisteredMethod 'Normal) +serverRegisterMethodNormal internalServer meth e = do + h <- serverRegisterMethod internalServer meth e Normal + return $ RegisteredMethodNormal meth e h + +serverRegisterMethodClientStreaming + :: C.Server + -> MethodName + -- ^ method name, e.g. "/foo" + -> Endpoint + -- ^ Endpoint name name, e.g. "localhost:9999". I have no + -- idea why this is needed since we have to provide these + -- parameters to start a server in the first place. It + -- doesn't seem to have any effect, even if it's filled + -- with nonsense. + -> IO (RegisteredMethod 'ClientStreaming) +serverRegisterMethodClientStreaming internalServer meth e = do + h <- serverRegisterMethod internalServer meth e ClientStreaming + return $ RegisteredMethodClientStreaming meth e h + + +serverRegisterMethodServerStreaming + :: C.Server + -> MethodName + -- ^ method name, e.g. "/foo" + -> Endpoint + -- ^ Endpoint name name, e.g. "localhost:9999". I have no + -- idea why this is needed since we have to provide these + -- parameters to start a server in the first place. It + -- doesn't seem to have any effect, even if it's filled + -- with nonsense. + -> IO (RegisteredMethod 'ServerStreaming) +serverRegisterMethodServerStreaming internalServer meth e = do + h <- serverRegisterMethod internalServer meth e ServerStreaming + return $ RegisteredMethodServerStreaming meth e h + + +serverRegisterMethodBiDiStreaming + :: C.Server + -> MethodName + -- ^ method name, e.g. "/foo" + -> Endpoint + -- ^ Endpoint name name, e.g. "localhost:9999". I have no + -- idea why this is needed since we have to provide these + -- parameters to start a server in the first place. It + -- doesn't seem to have any effect, even if it's filled + -- with nonsense. + -> IO (RegisteredMethod 'BiDiStreaming) +serverRegisterMethodBiDiStreaming internalServer meth e = do + h <- serverRegisterMethod internalServer meth e BiDiStreaming + return $ RegisteredMethodBiDiStreaming meth e h -- | Create a 'Call' with which to wait for the invocation of a registered -- method. serverCreateCall :: Server -> RegisteredMethod mt - -> IO (Either GRPCIOError ServerCall) -serverCreateCall Server{..} = serverRequestCall internalServer serverCQ + -> CompletionQueue -- ^ call CQ + -> IO (Either GRPCIOError (ServerCall (MethodPayload mt))) +serverCreateCall Server{..} rm = serverRequestCall rm unsafeServer serverCQ withServerCall :: Server -> RegisteredMethod mt - -> (ServerCall -> IO (Either GRPCIOError a)) + -> (ServerCall (MethodPayload mt) -> IO (Either GRPCIOError a)) -> IO (Either GRPCIOError a) -withServerCall server regmethod f = do - createResult <- serverCreateCall server regmethod - case createResult of - Left x -> return $ Left x - Right call -> f call `finally` logDestroy call - where logDestroy c = grpcDebug "withServerRegisteredCall: destroying." - >> destroyServerCall c +withServerCall s rm f = + withCompletionQueue (serverGRPC s) $ + serverCreateCall s rm >=> \case + Left e -> return (Left e) + Right c -> do + debugServerCall c + f c `finally` do + grpcDebug "withServerCall(R): destroying." + destroyServerCall c -------------------------------------------------------------------------------- -- serverReader (server side of client streaming mode) type ServerReaderHandler - = ServerCall - -> StreamRecv + = ServerCall () + -> StreamRecv ByteString -> Streaming (Maybe ByteString, MetadataMap, C.StatusCode, StatusDetails) serverReader :: Server @@ -182,24 +259,23 @@ serverReader :: Server -> MetadataMap -- ^ initial server metadata -> ServerReaderHandler -> IO (Either GRPCIOError ()) -serverReader s@Server{ serverCQ = cq } rm initMeta f = withServerCall s rm go +serverReader s rm initMeta f = withServerCall s rm go where - go sc@(unServerCall -> c) = runExceptT $ do - lift $ debugServerCall sc - (mmsg, trailMD, st, ds) <- - runStreamingProxy "serverReader" c cq (f sc streamRecv) - runOps' c cq ( OpSendInitialMetadata initMeta - : OpSendStatusFromServer trailMD st ds - : maybe [] ((:[]) . OpSendMessage) mmsg - ) + go sc@ServerCall{ unsafeSC = c, callCQ = ccq } = runExceptT $ do + (mmsg, trailMeta, st, ds) <- + runStreamingProxy "serverReader" c ccq (f sc streamRecv) + runOps' c ccq ( OpSendInitialMetadata initMeta + : OpSendStatusFromServer trailMeta st ds + : maybe [] ((:[]) . OpSendMessage) mmsg + ) return () -------------------------------------------------------------------------------- -- serverWriter (server side of server streaming mode) type ServerWriterHandler - = ServerCall - -> StreamSend + = ServerCall ByteString + -> StreamSend ByteString -> Streaming (MetadataMap, C.StatusCode, StatusDetails) -- | Wait for and then handle a registered, server-streaming call. @@ -209,21 +285,20 @@ serverWriter :: Server -- ^ Initial server metadata -> ServerWriterHandler -> IO (Either GRPCIOError ()) -serverWriter s@Server{ serverCQ = cq } rm initMeta f = withServerCall s rm go +serverWriter s rm initMeta f = withServerCall s rm go where - go sc@ServerCall{ unServerCall = c } = runExceptT $ do - lift (debugServerCall sc) - sendInitialMetadata c cq initMeta - st <- runStreamingProxy "serverWriter" c cq (f sc streamSend) - sendStatusFromServer c cq st + go sc@ServerCall{ unsafeSC = c, callCQ = ccq } = runExceptT $ do + sendInitialMetadata c ccq initMeta + st <- runStreamingProxy "serverWriter" c ccq (f sc streamSend) + sendStatusFromServer c ccq st -------------------------------------------------------------------------------- -- serverRW (server side of bidirectional streaming mode) type ServerRWHandler - = ServerCall - -> StreamRecv - -> StreamSend + = ServerCall () + -> StreamRecv ByteString + -> StreamSend ByteString -> Streaming (MetadataMap, C.StatusCode, StatusDetails) serverRW :: Server @@ -232,13 +307,12 @@ serverRW :: Server -- ^ initial server metadata -> ServerRWHandler -> IO (Either GRPCIOError ()) -serverRW s@Server{ serverCQ = cq } rm initMeta f = withServerCall s rm go +serverRW s rm initMeta f = withServerCall s rm go where - go sc@(unServerCall -> c) = runExceptT $ do - lift $ debugServerCall sc - sendInitialMetadata c cq initMeta - st <- runStreamingProxy "serverRW" c cq (f sc streamRecv streamSend) - sendStatusFromServer c cq st + go sc@ServerCall{ unsafeSC = c, callCQ = ccq } = runExceptT $ do + sendInitialMetadata c ccq initMeta + st <- runStreamingProxy "serverRW" c ccq (f sc streamRecv streamSend) + sendStatusFromServer c ccq st -------------------------------------------------------------------------------- -- serverHandleNormalCall (server side of normal request/response) @@ -250,7 +324,9 @@ serverRW s@Server{ serverCQ = cq } rm initMeta f = withServerCall s rm go -- respectively. We pass in the 'ServerCall' so that the server can call -- 'serverCallCancel' on it if needed. type ServerHandler - = ServerCall -> ByteString -> MetadataMap + = ServerCall ByteString + -> ByteString + -> MetadataMap -> IO (ByteString, MetadataMap, C.StatusCode, StatusDetails) -- | Wait for and then handle a normal (non-streaming) call. @@ -260,21 +336,15 @@ serverHandleNormalCall :: Server -- ^ Initial server metadata -> ServerHandler -> IO (Either GRPCIOError ()) -serverHandleNormalCall s@Server{ serverCQ = cq } rm initMeta f = +serverHandleNormalCall s rm initMeta f = withServerCall s rm go where - go sc@(unServerCall -> call) = do - grpcDebug "serverHandleNormalCall(R): starting batch." - debugServerCall sc - case optionalPayload sc of - Nothing -> return (Left GRPCIOInternalMissingExpectedPayload) - Just pay -> do - (rspBody, trailMeta, status, ds) <- f sc pay (requestMetadataRecv sc) - eea <- runOps call cq - [ OpSendInitialMetadata initMeta - , OpRecvCloseOnServer - , OpSendMessage rspBody - , OpSendStatusFromServer trailMeta status ds - ] - <* grpcDebug "serverHandleNormalCall(R): finished response ops." - return (void eea) + go sc@ServerCall{..} = do + (rsp, trailMeta, st, ds) <- f sc optionalPayload requestMetadataRecv + void <$> runOps unsafeSC callCQ + [ OpSendInitialMetadata initMeta + , OpRecvCloseOnServer + , OpSendMessage rsp + , OpSendStatusFromServer trailMeta st ds + ] + <* grpcDebug "serverHandleNormalCall(R): finished response ops." diff --git a/src/Network/GRPC/LowLevel/Server/Unregistered.hs b/src/Network/GRPC/LowLevel/Server/Unregistered.hs index 103dbcb..51445dc 100644 --- a/src/Network/GRPC/LowLevel/Server/Unregistered.hs +++ b/src/Network/GRPC/LowLevel/Server/Unregistered.hs @@ -4,8 +4,10 @@ module Network.GRPC.LowLevel.Server.Unregistered where import Control.Exception (finally) +import Control.Monad import Data.ByteString (ByteString) import Network.GRPC.LowLevel.Call.Unregistered +import Network.GRPC.LowLevel.CompletionQueue (CompletionQueue, withCompletionQueue) import Network.GRPC.LowLevel.CompletionQueue.Unregistered (serverRequestCall) import Network.GRPC.LowLevel.GRPC import Network.GRPC.LowLevel.Op (Op (..), OpRecvResult (..), @@ -13,20 +15,21 @@ import Network.GRPC.LowLevel.Op (Op (..), Op import Network.GRPC.LowLevel.Server (Server (..)) import qualified Network.GRPC.Unsafe.Op as C -serverCreateCall :: Server -> IO (Either GRPCIOError ServerCall) -serverCreateCall Server{..} = - serverRequestCall internalServer serverCQ +serverCreateCall :: Server + -> CompletionQueue -- ^ call CQ + -> IO (Either GRPCIOError ServerCall) +serverCreateCall Server{..} = serverRequestCall unsafeServer serverCQ withServerCall :: Server - -> (ServerCall -> IO (Either GRPCIOError a)) - -> IO (Either GRPCIOError a) -withServerCall server f = do - createResult <- serverCreateCall server - case createResult of - Left x -> return $ Left x - Right call -> f call `finally` logDestroy call - where logDestroy c = grpcDebug "withServerCall: destroying." - >> destroyServerCall c + -> (ServerCall -> IO (Either GRPCIOError a)) + -> IO (Either GRPCIOError a) +withServerCall s f = + withCompletionQueue (serverGRPC s) $ + serverCreateCall s >=> \case + Left e -> return (Left e) + Right c -> f c `finally` do + grpcDebug "withServerCall: destroying." + destroyServerCall c -- | Sequence of 'Op's needed to receive a normal (non-streaming) call. -- TODO: We have to put 'OpRecvCloseOnServer' in the response ops, or else the @@ -51,7 +54,8 @@ serverOpsSendNormalResponse body metadata code details = -- | A handler for an unregistered server call; bytestring arguments are the -- request body and response body respectively. type ServerHandler - = ServerCall -> ByteString + = ServerCall + -> ByteString -> IO (ByteString, MetadataMap, C.StatusCode, StatusDetails) -- | Handle one unregistered call. @@ -59,31 +63,32 @@ serverHandleNormalCall :: Server -> MetadataMap -- ^ Initial server metadata. -> ServerHandler -> IO (Either GRPCIOError ()) -serverHandleNormalCall s@Server{..} srvMetadata f = - withServerCall s $ \call@ServerCall{..} -> do - grpcDebug "serverHandleNormalCall(U): starting batch." - runOps unServerCall serverCQ - [ OpSendInitialMetadata srvMetadata - , OpRecvMessage - ] - >>= \case - Left x -> do - grpcDebug "serverHandleNormalCall(U): ops failed; aborting" - return $ Left x - Right [OpRecvMessageResult (Just body)] -> do - grpcDebug $ "got client metadata: " ++ show requestMetadataRecv - grpcDebug $ "call_details host is: " ++ show callHost - (rspBody, rspMeta, status, ds) <- f call body - runOps unServerCall serverCQ - [ OpRecvCloseOnServer - , OpSendMessage rspBody, - OpSendStatusFromServer rspMeta status ds - ] - >>= \case - Left x -> do - grpcDebug "serverHandleNormalCall(U): resp failed." - return $ Left x - Right _ -> do - grpcDebug "serverHandleNormalCall(U): ops done." - return $ Right () - x -> error $ "impossible pattern match: " ++ show x +serverHandleNormalCall s initMeta f = withServerCall s go + where + go sc@ServerCall{ unsafeSC = c, callCQ = cq, .. } = do + grpcDebug "serverHandleNormalCall(U): starting batch." + runOps c cq + [ OpSendInitialMetadata initMeta + , OpRecvMessage + ] + >>= \case + Left x -> do + grpcDebug "serverHandleNormalCall(U): ops failed; aborting" + return $ Left x + Right [OpRecvMessageResult (Just body)] -> do + grpcDebug $ "got client metadata: " ++ show requestMetadataRecv + grpcDebug $ "call_details host is: " ++ show callHost + (rsp, trailMeta, st, ds) <- f sc body + runOps c cq + [ OpRecvCloseOnServer + , OpSendMessage rsp, + OpSendStatusFromServer trailMeta st ds + ] + >>= \case + Left x -> do + grpcDebug "serverHandleNormalCall(U): resp failed." + return $ Left x + Right _ -> do + grpcDebug "serverHandleNormalCall(U): ops done." + return $ Right () + x -> error $ "impossible pattern match: " ++ show x diff --git a/stack.yaml b/stack.yaml index 1d865d1..b75ca7f 100644 --- a/stack.yaml +++ b/stack.yaml @@ -7,6 +7,14 @@ resolver: lts-5.10 # Local packages, usually specified by relative directory name packages: - '.' +- location: + git: git@github.mv.awakenetworks.net:awakenetworks/protobuf-wire.git + commit: e5a6985eeb5eb1eded7b46b2892874000e2ae835 + extra-dep: true +- location: + git: git@github.mv.awakenetworks.net:awakenetworks/proto3-wire.git + commit: 9898a793ab61fd582b5b7172d349ffd248b20095 + extra-dep: true # Packages to be pulled from upstream that are not in the resolver (e.g., acme-missiles-0.3) extra-deps: [managed-1.0.5] diff --git a/tests/LowLevelTests.hs b/tests/LowLevelTests.hs index 92e8de8..7271be5 100644 --- a/tests/LowLevelTests.hs +++ b/tests/LowLevelTests.hs @@ -70,24 +70,24 @@ testClientCreateDestroy = testClientTimeoutNoServer :: TestTree testClientTimeoutNoServer = clientOnlyTest "request timeout when server DNE" $ \c -> do - rm <- clientRegisterMethod c "/foo" Normal + rm <- clientRegisterMethodNormal c "/foo" r <- clientRequest c rm 1 "Hello" mempty r @?= Left GRPCIOTimeout testServerCreateDestroy :: TestTree testServerCreateDestroy = - serverOnlyTest "start/stop" [] nop + serverOnlyTest "start/stop" (["/foo"],[],[],[]) nop testMixRegisteredUnregistered :: TestTree testMixRegisteredUnregistered = csTest "server uses unregistered calls to handle unknown endpoints" client server - [("/foo", Normal)] + (["/foo"],[],[],[]) where client c = do - rm1 <- clientRegisterMethod c "/foo" Normal - rm2 <- clientRegisterMethod c "/bar" Normal + rm1 <- clientRegisterMethodNormal c "/foo" + rm2 <- clientRegisterMethodNormal c "/bar" clientRequest c rm1 1 "Hello" mempty >>= do checkReqRslt $ \NormalRequestResult{..} -> do rspBody @?= "reply test" @@ -121,11 +121,11 @@ testMixRegisteredUnregistered = -- tweak EH behavior / async use. testPayload :: TestTree testPayload = - csTest "registered normal request/response" client server [("/foo", Normal)] + csTest "registered normal request/response" client server (["/foo"],[],[],[]) where clientMD = [("foo_key", "foo_val"), ("bar_key", "bar_val")] client c = do - rm <- clientRegisterMethod c "/foo" Normal + rm <- clientRegisterMethodNormal c "/foo" clientRequest c rm 10 "Hello!" clientMD >>= do checkReqRslt $ \NormalRequestResult{..} -> do rspCode @?= StatusOk @@ -143,10 +143,10 @@ testPayload = testServerCancel :: TestTree testServerCancel = - csTest "server cancel call" client server [("/foo", Normal)] + csTest "server cancel call" client server (["/foo"],[],[],[]) where client c = do - rm <- clientRegisterMethod c "/foo" Normal + rm <- clientRegisterMethodNormal c "/foo" res <- clientRequest c rm 10 "" mempty res @?= badStatus StatusCancelled server s = do @@ -158,7 +158,7 @@ testServerCancel = testServerStreaming :: TestTree testServerStreaming = - csTest "server streaming" client server [("/feed", ServerStreaming)] + csTest "server streaming" client server ([],[],["/feed"],[]) where clientInitMD = [("client","initmd")] serverInitMD = [("server","initmd")] @@ -166,7 +166,7 @@ testServerStreaming = pays = ["ONE", "TWO", "THREE", "FOUR"] :: [ByteString] client c = do - rm <- clientRegisterMethod c "/feed" ServerStreaming + rm <- clientRegisterMethodServerStreaming c "/feed" eea <- clientReader c rm 10 clientPay clientInitMD $ \initMD recv -> do liftIO $ checkMD "Server initial metadata mismatch" serverInitMD initMD forM_ pays $ \p -> recv `is` Right (Just p) @@ -175,20 +175,18 @@ testServerStreaming = server s = do let rm = head (sstreamingMethods s) - eea <- serverWriter s rm serverInitMD $ \sc send -> do + r <- serverWriter s rm serverInitMD $ \sc send -> do liftIO $ do - checkMD "Client request metadata mismatch" + checkMD "Server request metadata mismatch" clientInitMD (requestMetadataRecv sc) - case optionalPayload sc of - Nothing -> assertFailure "expected optional payload" - Just pay -> pay @?= clientPay + optionalPayload sc @?= clientPay forM_ pays $ \p -> send p `is` Right () return (dummyMeta, StatusOk, "dtls") - eea @?= Right () + r @?= Right () testClientStreaming :: TestTree testClientStreaming = - csTest "client streaming" client server [("/slurp", ClientStreaming)] + csTest "client streaming" client server ([],["/slurp"],[],[]) where clientInitMD = [("a","b")] serverInitMD = [("x","y")] @@ -199,7 +197,7 @@ testClientStreaming = pays = ["P_ONE", "P_TWO", "P_THREE"] :: [ByteString] client c = do - rm <- clientRegisterMethod c "/slurp" ClientStreaming + rm <- clientRegisterMethodClientStreaming c "/slurp" eea <- clientWriter c rm 10 clientInitMD $ \send -> do -- liftIO $ checkMD "Server initial metadata mismatch" serverInitMD initMD forM_ pays $ \p -> send p `is` Right () @@ -217,19 +215,18 @@ testClientStreaming = testBiDiStreaming :: TestTree testBiDiStreaming = - csTest "bidirectional streaming" client server [("/bidi", BiDiStreaming)] + csTest "bidirectional streaming" client server ([],[],[],["/bidi"]) where clientInitMD = [("bidi-streaming","client")] serverInitMD = [("bidi-streaming","server")] trailMD = dummyMeta serverStatus = StatusOk serverDtls = "deets" + is act x = act >>= liftIO . (@?= x) client c = do - rm <- clientRegisterMethod c "/bidi" BiDiStreaming + rm <- clientRegisterMethodBiDiStreaming c "/bidi" eea <- clientRW c rm 10 clientInitMD $ \initMD recv send -> do - liftIO $ checkMD "Server initial metadata mismatch" - serverInitMD initMD send "cw0" `is` Right () recv `is` Right (Just "sw0") send "cw1" `is` Right () @@ -263,13 +260,13 @@ testClientCall = testServerCall :: TestTree testServerCall = - serverOnlyTest "create/destroy call" [] $ \s -> do + serverOnlyTest "create/destroy call" ([],[],[],[]) $ \s -> do r <- U.withServerCall s $ const $ return $ Right () r @?= Left GRPCIOTimeout testPayloadUnregistered :: TestTree testPayloadUnregistered = - csTest "unregistered normal request/response" client server [] + csTest "unregistered normal request/response" client server ([],[],[],[]) where client c = U.clientRequest c "/foo" 10 "Hello!" mempty >>= do @@ -289,10 +286,10 @@ testGoaway = csTest "Client handles server shutdown gracefully" client server - [("/foo", Normal)] + (["/foo"],[],[],[]) where client c = do - rm <- clientRegisterMethod c "/foo" Normal + rm <- clientRegisterMethodNormal c "/foo" clientRequest c rm 10 "" mempty clientRequest c rm 10 "" mempty lastResult <- clientRequest c rm 1 "" mempty @@ -310,10 +307,10 @@ testGoaway = testSlowServer :: TestTree testSlowServer = - csTest "Client handles slow server response" client server [("/foo", Normal)] + csTest "Client handles slow server response" client server (["/foo"],[],[],[]) where client c = do - rm <- clientRegisterMethod c "/foo" Normal + rm <- clientRegisterMethodNormal c "/foo" result <- clientRequest c rm 1 "" mempty result @?= badStatus StatusDeadlineExceeded server s = do @@ -325,10 +322,10 @@ testSlowServer = testServerCallExpirationCheck :: TestTree testServerCallExpirationCheck = - csTest "Check for call expiration" client server [("/foo", Normal)] + csTest "Check for call expiration" client server (["/foo"],[],[],[]) where client c = do - rm <- clientRegisterMethod c "/foo" Normal + rm <- clientRegisterMethodNormal c "/foo" result <- clientRequest c rm 3 "" mempty return () server s = do @@ -352,10 +349,10 @@ testCustomUserAgent = clientArgs = [UserAgentPrefix "prefix!", UserAgentSuffix "suffix!"] client = TestClient (ClientConfig "localhost" 50051 clientArgs) $ - \c -> do rm <- clientRegisterMethod c "/foo" Normal + \c -> do rm <- clientRegisterMethodNormal c "/foo" result <- clientRequest c rm 4 "" mempty return () - server = TestServer (stdServerConf [("/foo", Normal)]) $ \s -> do + server = TestServer (serverConf (["/foo"],[],[],[])) $ \s -> do let rm = head (normalMethods s) serverHandleNormalCall s rm mempty $ \_ _ meta -> do let ua = meta M.! "user-agent" @@ -373,10 +370,10 @@ testClientCompression = "localhost" 50051 [CompressionAlgArg GrpcCompressDeflate]) $ \c -> do - rm <- clientRegisterMethod c "/foo" Normal + rm <- clientRegisterMethodNormal c "/foo" result <- clientRequest c rm 1 "hello" mempty return () - server = TestServer (stdServerConf [("/foo", Normal)]) $ \s -> do + server = TestServer (serverConf (["/foo"],[],[],[])) $ \s -> do let rm = head (normalMethods s) serverHandleNormalCall s rm mempty $ \_ body _ -> do body @?= "hello" @@ -391,7 +388,7 @@ testClientServerCompression = 50051 [CompressionAlgArg GrpcCompressDeflate] client = TestClient cconf $ \c -> do - rm <- clientRegisterMethod c "/foo" Normal + rm <- clientRegisterMethodNormal c "/foo" clientRequest c rm 1 "hello" mempty >>= do checkReqRslt $ \NormalRequestResult{..} -> do rspCode @?= StatusOk @@ -402,7 +399,7 @@ testClientServerCompression = return () sconf = ServerConfig "localhost" 50051 - [("/foo", Normal)] + ["/foo"] [] [] [] [CompressionAlgArg GrpcCompressDeflate] server = TestServer sconf $ \s -> do let rm = head (normalMethods s) @@ -423,7 +420,7 @@ dummyMeta = [("foo","bar")] dummyResp :: (ByteString, MetadataMap, StatusCode, StatusDetails) dummyResp = ("", mempty, StatusOk, StatusDetails "") -dummyHandler :: ServerCall -> ByteString -> MetadataMap +dummyHandler :: ServerCall a -> ByteString -> MetadataMap -> IO (ByteString, MetadataMap, StatusCode, StatusDetails) dummyHandler _ _ _ = return dummyResp @@ -441,11 +438,11 @@ nop :: Monad m => a -> m () nop = const (return ()) serverOnlyTest :: TestName - -> [(MethodName, GRPCMethodType)] + -> ([MethodName],[MethodName],[MethodName],[MethodName]) -> (Server -> IO ()) -> TestTree serverOnlyTest nm ms = - testCase ("Server - " ++ nm) . runTestServer . stdTestServer ms + testCase ("Server - " ++ nm) . runTestServer . TestServer (serverConf ms) clientOnlyTest :: TestName -> (Client -> IO ()) -> TestTree clientOnlyTest nm = @@ -454,9 +451,10 @@ clientOnlyTest nm = csTest :: TestName -> (Client -> IO ()) -> (Server -> IO ()) - -> [(MethodName, GRPCMethodType)] + -> ([MethodName],[MethodName],[MethodName],[MethodName]) -> TestTree -csTest nm c s ms = csTest' nm (stdTestClient c) (stdTestServer ms s) +csTest nm c s ms = + csTest' nm (stdTestClient c) (TestServer (serverConf ms) s) csTest' :: TestName -> TestClient -> TestServer -> TestTree csTest' nm tc ts = @@ -505,11 +503,16 @@ runTestServer :: TestServer -> IO () runTestServer (TestServer conf f) = runManaged $ mgdGRPC >>= mgdServer conf >>= liftIO . f -stdTestServer :: [(MethodName, GRPCMethodType)] -> (Server -> IO ()) -> TestServer -stdTestServer = TestServer . stdServerConf +defServerConf :: ServerConfig +defServerConf = ServerConfig "localhost" 50051 [] [] [] [] [] -stdServerConf :: [(MethodName, GRPCMethodType)] -> ServerConfig -stdServerConf xs = ServerConfig "localhost" 50051 xs [] +serverConf :: ([MethodName],[MethodName],[MethodName],[MethodName]) + -> ServerConfig +serverConf (ns, cs, ss, bs) = + defServerConf {methodsToRegisterNormal = ns, + methodsToRegisterClientStreaming = cs, + methodsToRegisterServerStreaming = ss, + methodsToRegisterBiDiStreaming = bs} threadDelaySecs :: Int -> IO () threadDelaySecs = threadDelay . (* 10^(6::Int)) diff --git a/tests/LowLevelTests/Op.hs b/tests/LowLevelTests/Op.hs index 559aa7e..275b5ed 100644 --- a/tests/LowLevelTests/Op.hs +++ b/tests/LowLevelTests/Op.hs @@ -4,9 +4,7 @@ module LowLevelTests.Op where -import Control.Concurrent (threadDelay) -import Data.ByteString (isPrefixOf) -import Foreign.Storable (peek) +import Data.ByteString (ByteString, isPrefixOf) import Test.Tasty import Test.Tasty.HUnit as HU (testCase, (@?=), assertBool) @@ -16,7 +14,6 @@ import Network.GRPC.LowLevel.Call import Network.GRPC.LowLevel.Client import Network.GRPC.LowLevel.Server import Network.GRPC.LowLevel.Op -import Network.GRPC.LowLevel.CompletionQueue lowLevelOpTests :: TestTree lowLevelOpTests = testGroup "Synchronous unit tests of low-level Op interface" @@ -29,7 +26,7 @@ testCancelFromServer = withClientServerUnaryCall grpc $ \(Client{..}, Server{..}, ClientCall{..}, sc@ServerCall{..}) -> do serverCallCancel sc StatusPermissionDenied "TestStatus" - clientRes <- runOps unClientCall clientCQ clientRecvOps + clientRes <- runOps unsafeCC clientCQ clientRecvOps case clientRes of Left x -> error $ "Client recv error: " ++ show x Right [_,_,OpRecvStatusOnClientResult _ code details] -> do @@ -48,12 +45,13 @@ runSerialTest f = Right () -> return () withClientServerUnaryCall :: GRPC - -> ((Client, Server, ClientCall, ServerCall) + -> ((Client, Server, ClientCall, + ServerCall ByteString) -> IO (Either GRPCIOError a)) -> IO (Either GRPCIOError a) withClientServerUnaryCall grpc f = do withClient grpc clientConf $ \c -> do - crm <- clientRegisterMethod c "/foo" Normal + crm <- clientRegisterMethodNormal c "/foo" withServer grpc serverConf $ \s -> withClientCall c crm 10 $ \cc -> do let srm = head (normalMethods s) @@ -61,12 +59,12 @@ withClientServerUnaryCall grpc f = do -- because registered methods try to do recv ops immediately when -- created. If later we want to send payloads or metadata, we'll need -- to tweak this. - _clientRes <- runOps (unClientCall cc) (clientCQ c) clientEmptySendOps + _clientRes <- runOps (unsafeCC cc) (clientCQ c) clientEmptySendOps withServerCall s srm $ \sc -> f (c, s, cc, sc) serverConf :: ServerConfig -serverConf = ServerConfig "localhost" 50051 [("/foo", Normal)] [] +serverConf = ServerConfig "localhost" 50051 [("/foo")] [] [] [] [] clientConf :: ClientConfig clientConf = ClientConfig "localhost" 50051 []