aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--include/grpcpp/impl/codegen/call_op_set.h23
-rw-r--r--include/grpcpp/impl/codegen/interceptor.h9
-rw-r--r--include/grpcpp/impl/codegen/interceptor_common.h28
-rw-r--r--test/cpp/end2end/client_interceptors_end2end_test.cc83
4 files changed, 139 insertions, 4 deletions
diff --git a/include/grpcpp/impl/codegen/call_op_set.h b/include/grpcpp/impl/codegen/call_op_set.h
index b2100c68b7..1c0ccbab52 100644
--- a/include/grpcpp/impl/codegen/call_op_set.h
+++ b/include/grpcpp/impl/codegen/call_op_set.h
@@ -314,21 +314,37 @@ class CallOpSendMessage {
// Flags are per-message: clear them after use.
write_options_.Clear();
}
- void FinishOp(bool* status) { send_buf_.Clear(); }
+ void FinishOp(bool* status) {
+ if (!send_buf_.Valid()) {
+ return;
+ }
+ if (hijacked_ && failed_send_) {
+ // Hijacking interceptor failed this Op
+ *status = false;
+ } else if (!*status) {
+ // This Op was passed down to core and the Op failed
+ failed_send_ = true;
+ }
+ }
void SetInterceptionHookPoint(
InterceptorBatchMethodsImpl* interceptor_methods) {
if (!send_buf_.Valid()) return;
interceptor_methods->AddInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE);
- interceptor_methods->SetSendMessage(&send_buf_);
+ interceptor_methods->SetSendMessage(&send_buf_, &failed_send_);
}
void SetFinishInterceptionHookPoint(
InterceptorBatchMethodsImpl* interceptor_methods) {
+ if (send_buf_.Valid()) {
+ interceptor_methods->AddInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_SEND_MESSAGE);
+ }
+ send_buf_.Clear();
// The contents of the SendMessage value that was previously set
// has had its references stolen by core's operations
- interceptor_methods->SetSendMessage(nullptr);
+ interceptor_methods->SetSendMessage(nullptr, &failed_send_);
}
void SetHijackingState(InterceptorBatchMethodsImpl* interceptor_methods) {
@@ -337,6 +353,7 @@ class CallOpSendMessage {
private:
bool hijacked_ = false;
+ bool failed_send_ = false;
ByteBuffer send_buf_;
WriteOptions write_options_;
};
diff --git a/include/grpcpp/impl/codegen/interceptor.h b/include/grpcpp/impl/codegen/interceptor.h
index 46175cd73b..58b4f36d95 100644
--- a/include/grpcpp/impl/codegen/interceptor.h
+++ b/include/grpcpp/impl/codegen/interceptor.h
@@ -49,6 +49,7 @@ enum class InterceptionHookPoints {
/// The first two in this list are for clients and servers
PRE_SEND_INITIAL_METADATA,
PRE_SEND_MESSAGE,
+ POST_SEND_MESSAGE,
PRE_SEND_STATUS, // server only
PRE_SEND_CLOSE, // client only: WritesDone for stream; after write in unary
/// The following three are for hijacked clients only and can only be
@@ -111,6 +112,10 @@ class InterceptorBatchMethods {
/// A return value of nullptr indicates that this ByteBuffer is not valid.
virtual ByteBuffer* GetSendMessage() = 0;
+ /// Checks whether the SEND MESSAGE op succeeded. Valid for POST_SEND_MESSAGE
+ /// interceptions.
+ virtual bool GetSendMessageStatus() = 0;
+
/// Returns a modifiable multimap of the initial metadata to be sent. Valid
/// for PRE_SEND_INITIAL_METADATA interceptions. A value of nullptr indicates
/// that this field is not valid.
@@ -156,6 +161,10 @@ class InterceptorBatchMethods {
/// started from interceptors without infinite regress through the interceptor
/// list.
virtual std::unique_ptr<ChannelInterface> GetInterceptedChannel() = 0;
+
+ // On a hijacked RPC/ to-be hijacked RPC, this can be called to fail a SEND
+ // MESSAGE op
+ virtual void FailHijackedSendMessage() = 0;
};
/// Interface for an interceptor. Interceptor authors must create a class
diff --git a/include/grpcpp/impl/codegen/interceptor_common.h b/include/grpcpp/impl/codegen/interceptor_common.h
index d0aa23cb0a..b01706af8d 100644
--- a/include/grpcpp/impl/codegen/interceptor_common.h
+++ b/include/grpcpp/impl/codegen/interceptor_common.h
@@ -81,6 +81,8 @@ class InterceptorBatchMethodsImpl
ByteBuffer* GetSendMessage() override { return send_message_; }
+ bool GetSendMessageStatus() override { return !*fail_send_message_; }
+
std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override {
return send_initial_metadata_;
}
@@ -110,12 +112,21 @@ class InterceptorBatchMethodsImpl
Status* GetRecvStatus() override { return recv_status_; }
+ void FailHijackedSendMessage() override {
+ GPR_CODEGEN_ASSERT(hooks_[static_cast<size_t>(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)]);
+ *fail_send_message_ = true;
+ }
+
std::multimap<grpc::string_ref, grpc::string_ref>* GetRecvTrailingMetadata()
override {
return recv_trailing_metadata_->map();
}
- void SetSendMessage(ByteBuffer* buf) { send_message_ = buf; }
+ void SetSendMessage(ByteBuffer* buf, bool* fail_send_message) {
+ send_message_ = buf;
+ fail_send_message_ = fail_send_message;
+ }
void SetSendInitialMetadata(
std::multimap<grpc::string, grpc::string>* metadata) {
@@ -334,6 +345,7 @@ class InterceptorBatchMethodsImpl
std::function<void(void)> callback_;
ByteBuffer* send_message_ = nullptr;
+ bool* fail_send_message_ = nullptr;
std::multimap<grpc::string, grpc::string>* send_initial_metadata_;
@@ -386,6 +398,14 @@ class CancelInterceptorBatchMethods
return nullptr;
}
+ bool GetSendMessageStatus() override {
+ GPR_CODEGEN_ASSERT(
+ false &&
+ "It is illegal to call GetSendMessageStatus on a method which "
+ "has a Cancel notification");
+ return false;
+ }
+
std::multimap<grpc::string, grpc::string>* GetSendInitialMetadata() override {
GPR_CODEGEN_ASSERT(false &&
"It is illegal to call GetSendInitialMetadata on a "
@@ -451,6 +471,12 @@ class CancelInterceptorBatchMethods
"method which has a Cancel notification");
return std::unique_ptr<ChannelInterface>(nullptr);
}
+
+ void FailHijackedSendMessage() override {
+ GPR_CODEGEN_ASSERT(false &&
+ "It is illegal to call FailHijackedSendMessage on a "
+ "method which has a Cancel notification");
+ }
};
} // namespace internal
} // namespace grpc
diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc
index 8abf4eb3f4..596f20a542 100644
--- a/test/cpp/end2end/client_interceptors_end2end_test.cc
+++ b/test/cpp/end2end/client_interceptors_end2end_test.cc
@@ -270,6 +270,60 @@ class HijackingInterceptorMakesAnotherCallFactory
}
};
+class ClientStreamingRpcHijackingInterceptor
+ : public experimental::Interceptor {
+ public:
+ ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
+ info_ = info;
+ }
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ bool hijack = false;
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ hijack = true;
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ if (++count_ > 10) {
+ methods->FailHijackedSendMessage();
+ }
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) {
+ EXPECT_FALSE(got_failed_send_);
+ got_failed_send_ = !methods->GetSendMessageStatus();
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+ auto* status = methods->GetRecvStatus();
+ *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
+ }
+ if (hijack) {
+ methods->Hijack();
+ } else {
+ methods->Proceed();
+ }
+ }
+
+ static bool GotFailedSend() { return got_failed_send_; }
+
+ private:
+ experimental::ClientRpcInfo* info_;
+ int count_ = 0;
+ static bool got_failed_send_;
+};
+
+bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false;
+
+class ClientStreamingRpcHijackingInterceptorFactory
+ : public experimental::ClientInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateClientInterceptor(
+ experimental::ClientRpcInfo* info) override {
+ return new ClientStreamingRpcHijackingInterceptor(info);
+ }
+};
+
class LoggingInterceptor : public experimental::Interceptor {
public:
LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; }
@@ -546,6 +600,35 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
+TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
+ ChannelArguments args;
+ std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>(
+ new ClientStreamingRpcHijackingInterceptorFactory()));
+ auto channel = experimental::CreateCustomChannelWithInterceptors(
+ server_address_, InsecureChannelCredentials(), args, std::move(creators));
+
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ EchoResponse resp;
+ req.mutable_param()->set_echo_metadata(true);
+ req.set_message("Hello");
+ string expected_resp = "";
+ auto writer = stub->RequestStream(&ctx, &resp);
+ for (int i = 0; i < 10; i++) {
+ EXPECT_TRUE(writer->Write(req));
+ expected_resp += "Hello";
+ }
+ // The interceptor will reject the 11th message
+ writer->Write(req);
+ Status s = writer->Finish();
+ EXPECT_EQ(s.ok(), false);
+ EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
+}
+
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
ChannelArguments args;
DummyInterceptor::Reset();