diff options
author | Mark D. Roth <roth@google.com> | 2016-06-16 10:39:39 -0700 |
---|---|---|
committer | Mark D. Roth <roth@google.com> | 2016-06-16 10:39:39 -0700 |
commit | c008b33c18f1d2c49ee1b6683c2d13a85e2c2432 (patch) | |
tree | 7ae8188383c268aee1cdeb4fb47ca2110f1542ec | |
parent | 8a1a5976c7eced718e2385b7f36046d6e73017af (diff) |
Pass channel args to ChannelData ctor and ChannelData to CallData ctor.
-rw-r--r-- | include/grpc++/channel_filter.h | 79 | ||||
-rw-r--r-- | src/cpp/common/channel_filter.cc | 2 | ||||
-rw-r--r-- | test/cpp/end2end/filter_end2end_test.cc | 15 |
3 files changed, 49 insertions, 47 deletions
diff --git a/include/grpc++/channel_filter.h b/include/grpc++/channel_filter.h index b37d986a5a..8731a5e965 100644 --- a/include/grpc++/channel_filter.h +++ b/include/grpc++/channel_filter.h @@ -53,6 +53,20 @@ namespace grpc { +// Represents channel data. +// Note: Must be copyable. +class ChannelData { + public: + virtual ~ChannelData() {} + + virtual void StartTransportOp( + grpc_exec_ctx *exec_ctx, grpc_channel_element *elem, + grpc_transport_op *op); + + protected: + explicit ChannelData(const grpc_channel_args&) {} +}; + // Represents call data. // Note: Must be copyable. class CallData { @@ -70,21 +84,7 @@ class CallData { virtual char* GetPeer(grpc_exec_ctx *exec_ctx, grpc_call_element *elem); protected: - CallData() {} -}; - -// Represents channel data. -// Note: Must be copyable. -class ChannelData { - public: - virtual ~ChannelData() {} - - virtual void StartTransportOp( - grpc_exec_ctx *exec_ctx, grpc_channel_element *elem, - grpc_transport_op *op); - - protected: - ChannelData() {} + explicit CallData(const ChannelData&) {} }; namespace internal { @@ -93,13 +93,35 @@ namespace internal { template<typename ChannelDataType, typename CallDataType> class ChannelFilter { public: + static const size_t channel_data_size = sizeof(ChannelDataType); + + static void InitChannelElement( + grpc_exec_ctx *exec_ctx, grpc_channel_element *elem, + grpc_channel_element_args *args) { + // Construct the object in the already-allocated memory. + new (elem->channel_data) ChannelDataType(*args->channel_args); + } + + static void DestroyChannelElement( + grpc_exec_ctx *exec_ctx, grpc_channel_element *elem) { + reinterpret_cast<ChannelDataType*>(elem->channel_data)->~ChannelDataType(); + } + + static void StartTransportOp( + grpc_exec_ctx *exec_ctx, grpc_channel_element *elem, + grpc_transport_op *op) { + ChannelDataType* channel_data = (ChannelDataType*)elem->channel_data; + channel_data->StartTransportOp(exec_ctx, elem, op); + } + static const size_t call_data_size = sizeof(CallDataType); static void InitCallElement( grpc_exec_ctx *exec_ctx, grpc_call_element *elem, grpc_call_element_args *args) { + const ChannelDataType& channel_data = *(ChannelDataType*)elem->channel_data; // Construct the object in the already-allocated memory. - new (elem->call_data) CallDataType(); + new (elem->call_data) CallDataType(channel_data); } static void DestroyCallElement( @@ -127,33 +149,12 @@ class ChannelFilter { CallDataType* call_data = (CallDataType*)elem->call_data; return call_data->GetPeer(exec_ctx, elem); } - - static const size_t channel_data_size = sizeof(ChannelDataType); - - static void InitChannelElement( - grpc_exec_ctx *exec_ctx, grpc_channel_element *elem, - grpc_channel_element_args *args) { - // Construct the object in the already-allocated memory. - new (elem->channel_data) ChannelDataType(); - } - - static void DestroyChannelElement( - grpc_exec_ctx *exec_ctx, grpc_channel_element *elem) { - reinterpret_cast<ChannelDataType*>(elem->channel_data)->~ChannelDataType(); - } - - static void StartTransportOp( - grpc_exec_ctx *exec_ctx, grpc_channel_element *elem, - grpc_transport_op *op) { - ChannelDataType* channel_data = (ChannelDataType*)elem->channel_data; - channel_data->StartTransportOp(exec_ctx, elem, op); - } }; struct FilterRecord { grpc_channel_stack_type stack_type; int priority; - std::function<bool(const grpc_channel_args*)> include_filter; + std::function<bool(const grpc_channel_args&)> include_filter; grpc_channel_filter filter; }; extern std::vector<FilterRecord>* channel_filters; @@ -171,7 +172,7 @@ void ChannelFilterPluginShutdown(); template<typename ChannelDataType, typename CallDataType> void RegisterChannelFilter( const char* name, grpc_channel_stack_type stack_type, int priority, - std::function<bool(const grpc_channel_args*)> include_filter) { + std::function<bool(const grpc_channel_args&)> include_filter) { // If we haven't been called before, initialize channel_filters and // call grpc_register_plugin(). if (internal::channel_filters == nullptr) { diff --git a/src/cpp/common/channel_filter.cc b/src/cpp/common/channel_filter.cc index b5e5e08976..77b2a26e8c 100644 --- a/src/cpp/common/channel_filter.cc +++ b/src/cpp/common/channel_filter.cc @@ -83,7 +83,7 @@ bool MaybeAddFilter(grpc_channel_stack_builder* builder, void* arg) { if (filter.include_filter != nullptr) { const grpc_channel_args *args = grpc_channel_stack_builder_get_channel_arguments(builder); - if (!filter.include_filter(args)) + if (!filter.include_filter(*args)) return true; } return grpc_channel_stack_builder_prepend_filter( diff --git a/test/cpp/end2end/filter_end2end_test.cc b/test/cpp/end2end/filter_end2end_test.cc index be6988c6ba..16151c21b8 100644 --- a/test/cpp/end2end/filter_end2end_test.cc +++ b/test/cpp/end2end/filter_end2end_test.cc @@ -95,9 +95,16 @@ int GetCounterValue() { } // namespace +class ChannelDataImpl : public ChannelData { + public: + explicit ChannelDataImpl(const grpc_channel_args& args) : ChannelData(args) {} + virtual ~ChannelDataImpl() {} +}; + class CallDataImpl : public CallData { public: - CallDataImpl() {} + explicit CallDataImpl(const ChannelDataImpl& channel_data) + : CallData(channel_data) {} virtual ~CallDataImpl() {} void StartTransportStreamOp( @@ -109,12 +116,6 @@ class CallDataImpl : public CallData { } }; -class ChannelDataImpl : public ChannelData { - public: - ChannelDataImpl() {} - virtual ~ChannelDataImpl() {} -}; - class FilterEnd2endTest : public ::testing::Test { protected: FilterEnd2endTest() : server_host_("localhost") {} |