aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mark D. Roth <roth@google.com>2016-06-16 10:39:39 -0700
committerGravatar Mark D. Roth <roth@google.com>2016-06-16 10:39:39 -0700
commitc008b33c18f1d2c49ee1b6683c2d13a85e2c2432 (patch)
tree7ae8188383c268aee1cdeb4fb47ca2110f1542ec
parent8a1a5976c7eced718e2385b7f36046d6e73017af (diff)
Pass channel args to ChannelData ctor and ChannelData to CallData ctor.
-rw-r--r--include/grpc++/channel_filter.h79
-rw-r--r--src/cpp/common/channel_filter.cc2
-rw-r--r--test/cpp/end2end/filter_end2end_test.cc15
3 files changed, 49 insertions, 47 deletions
diff --git a/include/grpc++/channel_filter.h b/include/grpc++/channel_filter.h
index b37d986a5a..8731a5e965 100644
--- a/include/grpc++/channel_filter.h
+++ b/include/grpc++/channel_filter.h
@@ -53,6 +53,20 @@
namespace grpc {
+// Represents channel data.
+// Note: Must be copyable.
+class ChannelData {
+ public:
+ virtual ~ChannelData() {}
+
+ virtual void StartTransportOp(
+ grpc_exec_ctx *exec_ctx, grpc_channel_element *elem,
+ grpc_transport_op *op);
+
+ protected:
+ explicit ChannelData(const grpc_channel_args&) {}
+};
+
// Represents call data.
// Note: Must be copyable.
class CallData {
@@ -70,21 +84,7 @@ class CallData {
virtual char* GetPeer(grpc_exec_ctx *exec_ctx, grpc_call_element *elem);
protected:
- CallData() {}
-};
-
-// Represents channel data.
-// Note: Must be copyable.
-class ChannelData {
- public:
- virtual ~ChannelData() {}
-
- virtual void StartTransportOp(
- grpc_exec_ctx *exec_ctx, grpc_channel_element *elem,
- grpc_transport_op *op);
-
- protected:
- ChannelData() {}
+ explicit CallData(const ChannelData&) {}
};
namespace internal {
@@ -93,13 +93,35 @@ namespace internal {
template<typename ChannelDataType, typename CallDataType>
class ChannelFilter {
public:
+ static const size_t channel_data_size = sizeof(ChannelDataType);
+
+ static void InitChannelElement(
+ grpc_exec_ctx *exec_ctx, grpc_channel_element *elem,
+ grpc_channel_element_args *args) {
+ // Construct the object in the already-allocated memory.
+ new (elem->channel_data) ChannelDataType(*args->channel_args);
+ }
+
+ static void DestroyChannelElement(
+ grpc_exec_ctx *exec_ctx, grpc_channel_element *elem) {
+ reinterpret_cast<ChannelDataType*>(elem->channel_data)->~ChannelDataType();
+ }
+
+ static void StartTransportOp(
+ grpc_exec_ctx *exec_ctx, grpc_channel_element *elem,
+ grpc_transport_op *op) {
+ ChannelDataType* channel_data = (ChannelDataType*)elem->channel_data;
+ channel_data->StartTransportOp(exec_ctx, elem, op);
+ }
+
static const size_t call_data_size = sizeof(CallDataType);
static void InitCallElement(
grpc_exec_ctx *exec_ctx, grpc_call_element *elem,
grpc_call_element_args *args) {
+ const ChannelDataType& channel_data = *(ChannelDataType*)elem->channel_data;
// Construct the object in the already-allocated memory.
- new (elem->call_data) CallDataType();
+ new (elem->call_data) CallDataType(channel_data);
}
static void DestroyCallElement(
@@ -127,33 +149,12 @@ class ChannelFilter {
CallDataType* call_data = (CallDataType*)elem->call_data;
return call_data->GetPeer(exec_ctx, elem);
}
-
- static const size_t channel_data_size = sizeof(ChannelDataType);
-
- static void InitChannelElement(
- grpc_exec_ctx *exec_ctx, grpc_channel_element *elem,
- grpc_channel_element_args *args) {
- // Construct the object in the already-allocated memory.
- new (elem->channel_data) ChannelDataType();
- }
-
- static void DestroyChannelElement(
- grpc_exec_ctx *exec_ctx, grpc_channel_element *elem) {
- reinterpret_cast<ChannelDataType*>(elem->channel_data)->~ChannelDataType();
- }
-
- static void StartTransportOp(
- grpc_exec_ctx *exec_ctx, grpc_channel_element *elem,
- grpc_transport_op *op) {
- ChannelDataType* channel_data = (ChannelDataType*)elem->channel_data;
- channel_data->StartTransportOp(exec_ctx, elem, op);
- }
};
struct FilterRecord {
grpc_channel_stack_type stack_type;
int priority;
- std::function<bool(const grpc_channel_args*)> include_filter;
+ std::function<bool(const grpc_channel_args&)> include_filter;
grpc_channel_filter filter;
};
extern std::vector<FilterRecord>* channel_filters;
@@ -171,7 +172,7 @@ void ChannelFilterPluginShutdown();
template<typename ChannelDataType, typename CallDataType>
void RegisterChannelFilter(
const char* name, grpc_channel_stack_type stack_type, int priority,
- std::function<bool(const grpc_channel_args*)> include_filter) {
+ std::function<bool(const grpc_channel_args&)> include_filter) {
// If we haven't been called before, initialize channel_filters and
// call grpc_register_plugin().
if (internal::channel_filters == nullptr) {
diff --git a/src/cpp/common/channel_filter.cc b/src/cpp/common/channel_filter.cc
index b5e5e08976..77b2a26e8c 100644
--- a/src/cpp/common/channel_filter.cc
+++ b/src/cpp/common/channel_filter.cc
@@ -83,7 +83,7 @@ bool MaybeAddFilter(grpc_channel_stack_builder* builder, void* arg) {
if (filter.include_filter != nullptr) {
const grpc_channel_args *args =
grpc_channel_stack_builder_get_channel_arguments(builder);
- if (!filter.include_filter(args))
+ if (!filter.include_filter(*args))
return true;
}
return grpc_channel_stack_builder_prepend_filter(
diff --git a/test/cpp/end2end/filter_end2end_test.cc b/test/cpp/end2end/filter_end2end_test.cc
index be6988c6ba..16151c21b8 100644
--- a/test/cpp/end2end/filter_end2end_test.cc
+++ b/test/cpp/end2end/filter_end2end_test.cc
@@ -95,9 +95,16 @@ int GetCounterValue() {
} // namespace
+class ChannelDataImpl : public ChannelData {
+ public:
+ explicit ChannelDataImpl(const grpc_channel_args& args) : ChannelData(args) {}
+ virtual ~ChannelDataImpl() {}
+};
+
class CallDataImpl : public CallData {
public:
- CallDataImpl() {}
+ explicit CallDataImpl(const ChannelDataImpl& channel_data)
+ : CallData(channel_data) {}
virtual ~CallDataImpl() {}
void StartTransportStreamOp(
@@ -109,12 +116,6 @@ class CallDataImpl : public CallData {
}
};
-class ChannelDataImpl : public ChannelData {
- public:
- ChannelDataImpl() {}
- virtual ~ChannelDataImpl() {}
-};
-
class FilterEnd2endTest : public ::testing::Test {
protected:
FilterEnd2endTest() : server_host_("localhost") {}