diff options
Diffstat (limited to 'test/cpp/util/proto_reflection_descriptor_database.cc')
-rw-r--r-- | test/cpp/util/proto_reflection_descriptor_database.cc | 316 |
1 files changed, 316 insertions, 0 deletions
diff --git a/test/cpp/util/proto_reflection_descriptor_database.cc b/test/cpp/util/proto_reflection_descriptor_database.cc new file mode 100644 index 0000000000..25b720aee0 --- /dev/null +++ b/test/cpp/util/proto_reflection_descriptor_database.cc @@ -0,0 +1,316 @@ +/* + * + * Copyright 2016, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include "test/cpp/util/proto_reflection_descriptor_database.h" + +#include <vector> + +#include <grpc/support/log.h> + +using grpc::reflection::v1alpha::ServerReflection; +using grpc::reflection::v1alpha::ServerReflectionRequest; +using grpc::reflection::v1alpha::ServerReflectionResponse; +using grpc::reflection::v1alpha::ListServiceResponse; +using grpc::reflection::v1alpha::ErrorResponse; + +namespace grpc { + +ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase( + std::unique_ptr<ServerReflection::Stub> stub) + : stub_(std::move(stub)) {} + +ProtoReflectionDescriptorDatabase::ProtoReflectionDescriptorDatabase( + std::shared_ptr<grpc::Channel> channel) + : stub_(ServerReflection::NewStub(channel)) {} + +ProtoReflectionDescriptorDatabase::~ProtoReflectionDescriptorDatabase() {} + +bool ProtoReflectionDescriptorDatabase::FindFileByName( + const string& filename, google::protobuf::FileDescriptorProto* output) { + if (cached_db_.FindFileByName(filename, output)) { + return true; + } + + if (known_files_.find(filename) != known_files_.end()) { + return false; + } + + ServerReflectionRequest request; + request.set_file_by_filename(filename); + ServerReflectionResponse response; + + DoOneRequest(request, response); + + if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) { + AddFileFromResponse(response.file_descriptor_response()); + } else if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kErrorResponse) { + const ErrorResponse error = response.error_response(); + if (error.error_code() == StatusCode::NOT_FOUND) { + gpr_log(GPR_INFO, "NOT_FOUND from server for FindFileByName(%s)", + filename.c_str()); + } else { + gpr_log(GPR_INFO, + "Error on FindFileByName(%s)\n\tError code: %d\n" + "\tError Message: %s", + filename.c_str(), error.error_code(), + error.error_message().c_str()); + } + } else { + gpr_log( + GPR_INFO, + "Error on FindFileByName(%s) response type\n" + "\tExpecting: %d\n\tReceived: %d", + filename.c_str(), + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse, + response.message_response_case()); + } + + return cached_db_.FindFileByName(filename, output); +} + +bool ProtoReflectionDescriptorDatabase::FindFileContainingSymbol( + const string& symbol_name, google::protobuf::FileDescriptorProto* output) { + if (cached_db_.FindFileContainingSymbol(symbol_name, output)) { + return true; + } + + if (missing_symbols_.find(symbol_name) != missing_symbols_.end()) { + return false; + } + + ServerReflectionRequest request; + request.set_file_containing_symbol(symbol_name); + ServerReflectionResponse response; + + DoOneRequest(request, response); + + if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) { + AddFileFromResponse(response.file_descriptor_response()); + } else if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kErrorResponse) { + const ErrorResponse error = response.error_response(); + if (error.error_code() == StatusCode::NOT_FOUND) { + missing_symbols_.insert(symbol_name); + gpr_log(GPR_INFO, + "NOT_FOUND from server for FindFileContainingSymbol(%s)", + symbol_name.c_str()); + } else { + gpr_log(GPR_INFO, + "Error on FindFileContainingSymbol(%s)\n" + "\tError code: %d\n\tError Message: %s", + symbol_name.c_str(), error.error_code(), + error.error_message().c_str()); + } + } else { + gpr_log( + GPR_INFO, + "Error on FindFileContainingSymbol(%s) response type\n" + "\tExpecting: %d\n\tReceived: %d", + symbol_name.c_str(), + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse, + response.message_response_case()); + } + return cached_db_.FindFileContainingSymbol(symbol_name, output); +} + +bool ProtoReflectionDescriptorDatabase::FindFileContainingExtension( + const string& containing_type, int field_number, + google::protobuf::FileDescriptorProto* output) { + if (cached_db_.FindFileContainingExtension(containing_type, field_number, + output)) { + return true; + } + + if (missing_extensions_.find(containing_type) != missing_extensions_.end() && + missing_extensions_[containing_type].find(field_number) != + missing_extensions_[containing_type].end()) { + gpr_log(GPR_INFO, "nested map."); + return false; + } + + ServerReflectionRequest request; + request.mutable_file_containing_extension()->set_containing_type( + containing_type); + request.mutable_file_containing_extension()->set_extension_number( + field_number); + ServerReflectionResponse response; + + DoOneRequest(request, response); + + if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse) { + AddFileFromResponse(response.file_descriptor_response()); + } else if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kErrorResponse) { + const ErrorResponse error = response.error_response(); + if (error.error_code() == StatusCode::NOT_FOUND) { + if (missing_extensions_.find(containing_type) == + missing_extensions_.end()) { + missing_extensions_[containing_type] = {}; + } + missing_extensions_[containing_type].insert(field_number); + gpr_log(GPR_INFO, + "NOT_FOUND from server for FindFileContainingExtension(%s, %d)", + containing_type.c_str(), field_number); + } else { + gpr_log(GPR_INFO, + "Error on FindFileContainingExtension(%s, %d)\n" + "\tError code: %d\n\tError Message: %s", + containing_type.c_str(), field_number, error.error_code(), + error.error_message().c_str()); + } + } else { + gpr_log( + GPR_INFO, + "Error on FindFileContainingExtension(%s, %d) response type\n" + "\tExpecting: %d\n\tReceived: %d", + containing_type.c_str(), field_number, + ServerReflectionResponse::MessageResponseCase::kFileDescriptorResponse, + response.message_response_case()); + } + + return cached_db_.FindFileContainingExtension(containing_type, field_number, + output); +} + +bool ProtoReflectionDescriptorDatabase::FindAllExtensionNumbers( + const string& extendee_type, std::vector<int>* output) { + if (cached_extension_numbers_.find(extendee_type) != + cached_extension_numbers_.end()) { + *output = cached_extension_numbers_[extendee_type]; + return true; + } + + ServerReflectionRequest request; + request.set_all_extension_numbers_of_type(extendee_type); + ServerReflectionResponse response; + + DoOneRequest(request, response); + + if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase:: + kAllExtensionNumbersResponse) { + auto number = response.all_extension_numbers_response().extension_number(); + *output = std::vector<int>(number.begin(), number.end()); + cached_extension_numbers_[extendee_type] = *output; + return true; + } else if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kErrorResponse) { + const ErrorResponse error = response.error_response(); + if (error.error_code() == StatusCode::NOT_FOUND) { + gpr_log(GPR_INFO, "NOT_FOUND from server for FindAllExtensionNumbers(%s)", + extendee_type.c_str()); + } else { + gpr_log(GPR_INFO, + "Error on FindAllExtensionNumbersExtension(%s)\n" + "\tError code: %d\n\tError Message: %s", + extendee_type.c_str(), error.error_code(), + error.error_message().c_str()); + } + } + return false; +} + +bool ProtoReflectionDescriptorDatabase::GetServices( + std::vector<std::string>* output) { + ServerReflectionRequest request; + request.set_list_services(""); + ServerReflectionResponse response; + + DoOneRequest(request, response); + + if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kListServicesResponse) { + const ListServiceResponse ls_response = response.list_services_response(); + for (int i = 0; i < ls_response.service_size(); ++i) { + (*output).push_back(ls_response.service(i).name()); + } + return true; + } else if (response.message_response_case() == + ServerReflectionResponse::MessageResponseCase::kErrorResponse) { + const ErrorResponse error = response.error_response(); + gpr_log(GPR_INFO, + "Error on GetServices()\n\tError code: %d\n" + "\tError Message: %s", + error.error_code(), error.error_message().c_str()); + } else { + gpr_log( + GPR_INFO, + "Error on GetServices() response type\n\tExpecting: %d\n\tReceived: %d", + ServerReflectionResponse::MessageResponseCase::kListServicesResponse, + response.message_response_case()); + } + return false; +} + +const google::protobuf::FileDescriptorProto +ProtoReflectionDescriptorDatabase::ParseFileDescriptorProtoResponse( + const std::string& byte_fd_proto) { + google::protobuf::FileDescriptorProto file_desc_proto; + file_desc_proto.ParseFromString(byte_fd_proto); + return file_desc_proto; +} + +void ProtoReflectionDescriptorDatabase::AddFileFromResponse( + const grpc::reflection::v1alpha::FileDescriptorResponse& response) { + for (int i = 0; i < response.file_descriptor_proto_size(); ++i) { + const google::protobuf::FileDescriptorProto file_proto = + ParseFileDescriptorProtoResponse(response.file_descriptor_proto(i)); + if (known_files_.find(file_proto.name()) == known_files_.end()) { + known_files_.insert(file_proto.name()); + cached_db_.Add(file_proto); + } + } +} + +const std::shared_ptr<ProtoReflectionDescriptorDatabase::ClientStream> +ProtoReflectionDescriptorDatabase::GetStream() { + if (!stream_) { + stream_ = stub_->ServerReflectionInfo(&ctx_); + } + return stream_; +} + +void ProtoReflectionDescriptorDatabase::DoOneRequest( + const ServerReflectionRequest& request, + ServerReflectionResponse& response) { + stream_mutex_.lock(); + GetStream()->Write(request); + GetStream()->Read(&response); + stream_mutex_.unlock(); +} + +} // namespace grpc |