diff options
Diffstat (limited to 'test/core/tsi/alts/handshaker')
8 files changed, 2383 insertions, 0 deletions
diff --git a/test/core/tsi/alts/handshaker/BUILD b/test/core/tsi/alts/handshaker/BUILD new file mode 100644 index 0000000000..fc2c395bdf --- /dev/null +++ b/test/core/tsi/alts/handshaker/BUILD @@ -0,0 +1,86 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("//bazel:grpc_build_system.bzl", "grpc_cc_library", "grpc_cc_test", "grpc_package") + +licenses(["notice"]) # Apache v2 + +grpc_package(name = "handshaker") + +grpc_cc_library( + name = "alts_handshaker_service_api_test_lib", + srcs = ["alts_handshaker_service_api_test_lib.cc"], + hdrs = ["alts_handshaker_service_api_test_lib.h"], + deps = [ + "//:alts_util", + "//:grpc", + ], +) + +grpc_cc_test( + name = "alts_handshaker_client_test", + srcs = ["alts_handshaker_client_test.cc"], + language = "C++", + deps = [ + ":alts_handshaker_service_api_test_lib", + "//:tsi", + "//:tsi_interface", + "//:grpc", + ], +) + +grpc_cc_test( + name = "alts_handshaker_service_api_test", + srcs = ["alts_handshaker_service_api_test.cc"], + language = "C++", + deps = [ + ":alts_handshaker_service_api_test_lib", + "//:grpc", + ], +) + +grpc_cc_test( + name = "alts_tsi_handshaker_test", + srcs = ["alts_tsi_handshaker_test.cc"], + language = "C++", + deps = [ + ":alts_handshaker_service_api_test_lib", + "//:gpr", + "//:gpr_base", + "//:grpc", + "//:tsi", + ], +) + +grpc_cc_test( + name = "alts_tsi_utils_test", + srcs = ["alts_tsi_utils_test.cc"], + language = "C++", + deps = [ + ":alts_handshaker_service_api_test_lib", + "//:grpc", + "//:tsi", + ], +) + +grpc_cc_test( + name = "transport_security_common_api_test", + srcs = ["transport_security_common_api_test.cc"], + language = "C++", + deps = [ + "//:alts_util", + "//:grpc", + ], +) + diff --git a/test/core/tsi/alts/handshaker/alts_handshaker_client_test.cc b/test/core/tsi/alts/handshaker/alts_handshaker_client_test.cc new file mode 100644 index 0000000000..7072be6e3a --- /dev/null +++ b/test/core/tsi/alts/handshaker/alts_handshaker_client_test.cc @@ -0,0 +1,412 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <grpc/grpc.h> + +#include "src/core/tsi/alts/handshaker/alts_handshaker_client.h" +#include "src/core/tsi/alts/handshaker/alts_tsi_event.h" +#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" +#include "src/core/tsi/transport_security.h" +#include "src/core/tsi/transport_security_interface.h" +#include "test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.h" + +#define ALTS_HANDSHAKER_CLIENT_TEST_OUT_FRAME "Hello Google" +#define ALTS_HANDSHAKER_CLIENT_TEST_HANDSHAKER_SERVICE_URL "lame" +#define ALTS_HANDSHAKER_CLIENT_TEST_TARGET_NAME "bigtable.google.api.com" +#define ALTS_HANDSHAKER_CLIENT_TEST_TARGET_SERVICE_ACCOUNT1 "A@google.com" +#define ALTS_HANDSHAKER_CLIENT_TEST_TARGET_SERVICE_ACCOUNT2 "B@google.com" + +const size_t kHandshakerClientOpNum = 4; +const size_t kMaxRpcVersionMajor = 3; +const size_t kMaxRpcVersionMinor = 2; +const size_t kMinRpcVersionMajor = 2; +const size_t kMinRpcVersionMinor = 1; + +using grpc_core::internal::alts_handshaker_client_set_grpc_caller_for_testing; + +typedef struct alts_handshaker_client_test_config { + grpc_channel* channel; + grpc_completion_queue* cq; + alts_handshaker_client* client; + grpc_slice out_frame; +} alts_handshaker_client_test_config; + +static alts_tsi_event* alts_tsi_event_create_for_testing(bool is_client) { + alts_tsi_event* e = static_cast<alts_tsi_event*>(gpr_zalloc(sizeof(*e))); + grpc_metadata_array_init(&e->initial_metadata); + grpc_metadata_array_init(&e->trailing_metadata); + e->options = is_client ? grpc_alts_credentials_client_options_create() + : grpc_alts_credentials_server_options_create(); + if (is_client) { + grpc_alts_credentials_client_options_add_target_service_account( + reinterpret_cast<grpc_alts_credentials_client_options*>(e->options), + ALTS_HANDSHAKER_CLIENT_TEST_TARGET_SERVICE_ACCOUNT1); + grpc_alts_credentials_client_options_add_target_service_account( + reinterpret_cast<grpc_alts_credentials_client_options*>(e->options), + ALTS_HANDSHAKER_CLIENT_TEST_TARGET_SERVICE_ACCOUNT2); + } + grpc_gcp_rpc_protocol_versions* versions = &e->options->rpc_versions; + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max( + versions, kMaxRpcVersionMajor, kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min( + versions, kMinRpcVersionMajor, kMinRpcVersionMinor)); + e->target_name = + grpc_slice_from_static_string(ALTS_HANDSHAKER_CLIENT_TEST_TARGET_NAME); + return e; +} + +static void validate_rpc_protocol_versions( + grpc_gcp_rpc_protocol_versions* versions) { + GPR_ASSERT(versions != nullptr); + GPR_ASSERT(versions->max_rpc_version.major == kMaxRpcVersionMajor); + GPR_ASSERT(versions->max_rpc_version.minor == kMaxRpcVersionMinor); + GPR_ASSERT(versions->min_rpc_version.major == kMinRpcVersionMajor); + GPR_ASSERT(versions->min_rpc_version.minor == kMinRpcVersionMinor); +} + +static void validate_target_identities( + const repeated_field* target_identity_head) { + grpc_gcp_identity* target_identity1 = static_cast<grpc_gcp_identity*>( + const_cast<void*>(target_identity_head->next->data)); + grpc_gcp_identity* target_identity2 = static_cast<grpc_gcp_identity*>( + const_cast<void*>(target_identity_head->data)); + grpc_slice* service_account1 = + static_cast<grpc_slice*>(target_identity1->service_account.arg); + grpc_slice* service_account2 = + static_cast<grpc_slice*>(target_identity2->service_account.arg); + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(*service_account1), + ALTS_HANDSHAKER_CLIENT_TEST_TARGET_SERVICE_ACCOUNT1, + GRPC_SLICE_LENGTH(*service_account1)) == 0); + GPR_ASSERT(strlen(ALTS_HANDSHAKER_CLIENT_TEST_TARGET_SERVICE_ACCOUNT1) == + GRPC_SLICE_LENGTH(*service_account1)); + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(*service_account2), + ALTS_HANDSHAKER_CLIENT_TEST_TARGET_SERVICE_ACCOUNT2, + GRPC_SLICE_LENGTH(*service_account2)) == 0); + GPR_ASSERT(strlen(ALTS_HANDSHAKER_CLIENT_TEST_TARGET_SERVICE_ACCOUNT2) == + GRPC_SLICE_LENGTH(*service_account2)); +} + +/** + * Validate if grpc operation data is correctly populated with the fields of + * ALTS TSI event. + */ +static bool validate_op(alts_tsi_event* event, const grpc_op* op, size_t nops, + bool is_start) { + GPR_ASSERT(event != nullptr && op != nullptr && nops != 0); + bool ok = true; + grpc_op* start_op = const_cast<grpc_op*>(op); + if (is_start) { + ok &= (op->op == GRPC_OP_SEND_INITIAL_METADATA); + ok &= (op->data.send_initial_metadata.count == 0); + op++; + GPR_ASSERT((size_t)(op - start_op) <= kHandshakerClientOpNum); + + ok &= (op->op == GRPC_OP_RECV_INITIAL_METADATA); + ok &= (op->data.recv_initial_metadata.recv_initial_metadata == + &event->initial_metadata); + op++; + GPR_ASSERT((size_t)(op - start_op) <= kHandshakerClientOpNum); + } + ok &= (op->op == GRPC_OP_SEND_MESSAGE); + ok &= (op->data.send_message.send_message == event->send_buffer); + op++; + GPR_ASSERT((size_t)(op - start_op) <= kHandshakerClientOpNum); + + ok &= (op->op == GRPC_OP_RECV_MESSAGE); + ok &= (op->data.recv_message.recv_message == &event->recv_buffer); + op++; + GPR_ASSERT((size_t)(op - start_op) <= kHandshakerClientOpNum); + + return ok; +} + +static grpc_gcp_handshaker_req* deserialize_handshaker_req( + grpc_gcp_handshaker_req_type type, grpc_byte_buffer* buffer) { + GPR_ASSERT(buffer != nullptr); + grpc_gcp_handshaker_req* req = grpc_gcp_handshaker_decoded_req_create(type); + grpc_byte_buffer_reader bbr; + GPR_ASSERT(grpc_byte_buffer_reader_init(&bbr, buffer)); + grpc_slice slice = grpc_byte_buffer_reader_readall(&bbr); + GPR_ASSERT(grpc_gcp_handshaker_req_decode(slice, req)); + grpc_slice_unref(slice); + grpc_byte_buffer_reader_destroy(&bbr); + return req; +} + +/** + * A mock grpc_caller used to check if client_start, server_start, and next + * operations correctly handle invalid arguments. It should not be called. + */ +static grpc_call_error check_must_not_be_called(grpc_call* call, + const grpc_op* ops, size_t nops, + void* tag) { + GPR_ASSERT(0); +} + +/** + * A mock grpc_caller used to check correct execution of client_start operation. + * It checks if the client_start handshaker request is populated with correct + * handshake_security_protocol, application_protocol, and record_protocol, and + * op is correctly populated. + */ +static grpc_call_error check_client_start_success(grpc_call* call, + const grpc_op* op, + size_t nops, void* tag) { + alts_tsi_event* event = static_cast<alts_tsi_event*>(tag); + grpc_gcp_handshaker_req* req = + deserialize_handshaker_req(CLIENT_START_REQ, event->send_buffer); + GPR_ASSERT(req->client_start.handshake_security_protocol == + grpc_gcp_HandshakeProtocol_ALTS); + const void* data = (static_cast<repeated_field*>( + req->client_start.application_protocols.arg)) + ->data; + GPR_ASSERT(data != nullptr); + grpc_slice* application_protocol = (grpc_slice*)data; + data = (static_cast<repeated_field*>(req->client_start.record_protocols.arg)) + ->data; + grpc_slice* record_protocol = (grpc_slice*)data; + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(*application_protocol), + ALTS_APPLICATION_PROTOCOL, + GRPC_SLICE_LENGTH(*application_protocol)) == 0); + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(*record_protocol), + ALTS_RECORD_PROTOCOL, + GRPC_SLICE_LENGTH(*record_protocol)) == 0); + validate_rpc_protocol_versions(&req->client_start.rpc_versions); + validate_target_identities( + static_cast<repeated_field*>(req->client_start.target_identities.arg)); + grpc_slice* target_name = + static_cast<grpc_slice*>(req->client_start.target_name.arg); + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(*target_name), + ALTS_HANDSHAKER_CLIENT_TEST_TARGET_NAME, + GRPC_SLICE_LENGTH(*target_name)) == 0); + GPR_ASSERT(GRPC_SLICE_LENGTH(*target_name) == + strlen(ALTS_HANDSHAKER_CLIENT_TEST_TARGET_NAME)); + GPR_ASSERT(validate_op(event, op, nops, true /* is_start */)); + grpc_gcp_handshaker_req_destroy(req); + return GRPC_CALL_OK; +} + +/** + * A mock grpc_caller used to check correct execution of server_start operation. + * It checks if the server_start handshaker request is populated with correct + * handshake_security_protocol, application_protocol, and record_protocol, and + * op is correctly populated. + */ +static grpc_call_error check_server_start_success(grpc_call* call, + const grpc_op* op, + size_t nops, void* tag) { + alts_tsi_event* event = static_cast<alts_tsi_event*>(tag); + grpc_gcp_handshaker_req* req = + deserialize_handshaker_req(SERVER_START_REQ, event->send_buffer); + const void* data = (static_cast<repeated_field*>( + req->server_start.application_protocols.arg)) + ->data; + GPR_ASSERT(data != nullptr); + grpc_slice* application_protocol = (grpc_slice*)data; + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(*application_protocol), + ALTS_APPLICATION_PROTOCOL, + GRPC_SLICE_LENGTH(*application_protocol)) == 0); + GPR_ASSERT(req->server_start.handshake_parameters_count == 1); + GPR_ASSERT(req->server_start.handshake_parameters[0].key == + grpc_gcp_HandshakeProtocol_ALTS); + data = (static_cast<repeated_field*>(req->server_start.handshake_parameters[0] + .value.record_protocols.arg)) + ->data; + GPR_ASSERT(data != nullptr); + grpc_slice* record_protocol = (grpc_slice*)data; + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(*record_protocol), + ALTS_RECORD_PROTOCOL, + GRPC_SLICE_LENGTH(*record_protocol)) == 0); + validate_rpc_protocol_versions(&req->server_start.rpc_versions); + GPR_ASSERT(validate_op(event, op, nops, true /* is_start */)); + grpc_gcp_handshaker_req_destroy(req); + return GRPC_CALL_OK; +} + +/** + * A mock grpc_caller used to check correct execution of next operation. It + * checks if the next handshaker request is populated with correct information, + * and op is correctly populated. + */ +static grpc_call_error check_next_success(grpc_call* call, const grpc_op* op, + size_t nops, void* tag) { + alts_tsi_event* event = static_cast<alts_tsi_event*>(tag); + grpc_gcp_handshaker_req* req = + deserialize_handshaker_req(NEXT_REQ, event->send_buffer); + grpc_slice* in_bytes = static_cast<grpc_slice*>(req->next.in_bytes.arg); + GPR_ASSERT(in_bytes != nullptr); + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(*in_bytes), + ALTS_HANDSHAKER_CLIENT_TEST_OUT_FRAME, + GRPC_SLICE_LENGTH(*in_bytes)) == 0); + GPR_ASSERT(validate_op(event, op, nops, false /* is_start */)); + grpc_gcp_handshaker_req_destroy(req); + return GRPC_CALL_OK; +} +/** + * A mock grpc_caller used to check if client_start, server_start, and next + * operations correctly handle the situation when the grpc call made to the + * handshaker service fails. + */ +static grpc_call_error check_grpc_call_failure(grpc_call* call, + const grpc_op* op, size_t nops, + void* tag) { + return GRPC_CALL_ERROR; +} + +static alts_handshaker_client_test_config* create_config() { + alts_handshaker_client_test_config* config = + static_cast<alts_handshaker_client_test_config*>( + gpr_zalloc(sizeof(*config))); + config->channel = grpc_insecure_channel_create( + ALTS_HANDSHAKER_CLIENT_TEST_HANDSHAKER_SERVICE_URL, nullptr, nullptr); + config->cq = grpc_completion_queue_create_for_next(nullptr); + config->client = alts_grpc_handshaker_client_create( + config->channel, config->cq, + ALTS_HANDSHAKER_CLIENT_TEST_HANDSHAKER_SERVICE_URL); + GPR_ASSERT(config->client != nullptr); + config->out_frame = + grpc_slice_from_static_string(ALTS_HANDSHAKER_CLIENT_TEST_OUT_FRAME); + return config; +} + +static void destroy_config(alts_handshaker_client_test_config* config) { + if (config == nullptr) { + return; + } + grpc_completion_queue_destroy(config->cq); + grpc_channel_destroy(config->channel); + alts_handshaker_client_destroy(config->client); + grpc_slice_unref(config->out_frame); + gpr_free(config); +} + +static void schedule_request_invalid_arg_test() { + /* Initialization. */ + alts_handshaker_client_test_config* config = create_config(); + alts_tsi_event* event = nullptr; + + /* Tests. */ + alts_handshaker_client_set_grpc_caller_for_testing(config->client, + check_must_not_be_called); + event = alts_tsi_event_create_for_testing(true /* is_client */); + /* Check client_start. */ + GPR_ASSERT(alts_handshaker_client_start_client(nullptr, event) == + TSI_INVALID_ARGUMENT); + GPR_ASSERT(alts_handshaker_client_start_client(config->client, nullptr) == + TSI_INVALID_ARGUMENT); + + /* Check server_start. */ + GPR_ASSERT(alts_handshaker_client_start_server( + config->client, event, nullptr) == TSI_INVALID_ARGUMENT); + GPR_ASSERT(alts_handshaker_client_start_server(config->client, nullptr, + &config->out_frame) == + TSI_INVALID_ARGUMENT); + GPR_ASSERT(alts_handshaker_client_start_server( + nullptr, event, &config->out_frame) == TSI_INVALID_ARGUMENT); + + /* Check next. */ + GPR_ASSERT(alts_handshaker_client_next(config->client, event, nullptr) == + TSI_INVALID_ARGUMENT); + GPR_ASSERT(alts_handshaker_client_next(config->client, nullptr, + &config->out_frame) == + TSI_INVALID_ARGUMENT); + GPR_ASSERT(alts_handshaker_client_next(nullptr, event, &config->out_frame) == + TSI_INVALID_ARGUMENT); + + /* Cleanup. */ + alts_tsi_event_destroy(event); + destroy_config(config); +} + +static void schedule_request_success_test() { + /* Initialization. */ + alts_handshaker_client_test_config* config = create_config(); + alts_tsi_event* event = nullptr; + + /* Check client_start success. */ + alts_handshaker_client_set_grpc_caller_for_testing( + config->client, check_client_start_success); + event = alts_tsi_event_create_for_testing(true /* is_client. */); + GPR_ASSERT(alts_handshaker_client_start_client(config->client, event) == + TSI_OK); + alts_tsi_event_destroy(event); + + /* Check server_start success. */ + alts_handshaker_client_set_grpc_caller_for_testing( + config->client, check_server_start_success); + event = alts_tsi_event_create_for_testing(false /* is_client. */); + GPR_ASSERT(alts_handshaker_client_start_server(config->client, event, + &config->out_frame) == TSI_OK); + alts_tsi_event_destroy(event); + + /* Check next success. */ + alts_handshaker_client_set_grpc_caller_for_testing(config->client, + check_next_success); + event = alts_tsi_event_create_for_testing(true /* is_client. */); + GPR_ASSERT(alts_handshaker_client_next(config->client, event, + &config->out_frame) == TSI_OK); + alts_tsi_event_destroy(event); + + /* Cleanup. */ + destroy_config(config); +} + +static void schedule_request_grpc_call_failure_test() { + /* Initialization. */ + alts_handshaker_client_test_config* config = create_config(); + alts_tsi_event* event = nullptr; + + /* Check client_start failure. */ + alts_handshaker_client_set_grpc_caller_for_testing(config->client, + check_grpc_call_failure); + event = alts_tsi_event_create_for_testing(true /* is_client. */); + GPR_ASSERT(alts_handshaker_client_start_client(config->client, event) == + TSI_INTERNAL_ERROR); + alts_tsi_event_destroy(event); + + /* Check server_start failure. */ + event = alts_tsi_event_create_for_testing(false /* is_client. */); + GPR_ASSERT(alts_handshaker_client_start_server(config->client, event, + &config->out_frame) == + TSI_INTERNAL_ERROR); + alts_tsi_event_destroy(event); + + /* Check next failure. */ + event = alts_tsi_event_create_for_testing(true /* is_cleint. */); + GPR_ASSERT( + alts_handshaker_client_next(config->client, event, &config->out_frame) == + TSI_INTERNAL_ERROR); + alts_tsi_event_destroy(event); + + /* Cleanup. */ + destroy_config(config); +} + +int main(int argc, char** argv) { + /* Initialization. */ + grpc_init(); + + /* Tests. */ + schedule_request_invalid_arg_test(); + schedule_request_success_test(); + schedule_request_grpc_call_failure_test(); + + /* Cleanup. */ + grpc_shutdown(); + return 0; +} diff --git a/test/core/tsi/alts/handshaker/alts_handshaker_service_api_test.cc b/test/core/tsi/alts/handshaker/alts_handshaker_service_api_test.cc new file mode 100644 index 0000000000..3506264f52 --- /dev/null +++ b/test/core/tsi/alts/handshaker/alts_handshaker_service_api_test.cc @@ -0,0 +1,149 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <stdbool.h> +#include <stdio.h> +#include <stdlib.h> + +#include "test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.h" + +int main(int argc, char** argv) { + const char in_bytes[] = "HELLO GOOGLE!"; + const char out_frames[] = "HELLO WORLD!"; + const char key_data[] = "THIS IS KEY DATA."; + const char details[] = "DETAILS NEED TO BE POPULATED"; + const uint32_t max_rpc_version_major = 3; + const uint32_t max_rpc_version_minor = 2; + const uint32_t min_rpc_version_major = 2; + const uint32_t min_rpc_version_minor = 1; + + /* handshaker_req_next. */ + grpc_gcp_handshaker_req* req = grpc_gcp_handshaker_req_create(NEXT_REQ); + grpc_gcp_handshaker_req* decoded_req = + grpc_gcp_handshaker_decoded_req_create(NEXT_REQ); + GPR_ASSERT( + grpc_gcp_handshaker_req_set_in_bytes(req, in_bytes, strlen(in_bytes))); + grpc_slice encoded_req; + GPR_ASSERT(grpc_gcp_handshaker_req_encode(req, &encoded_req)); + GPR_ASSERT(grpc_gcp_handshaker_req_decode(encoded_req, decoded_req)); + GPR_ASSERT(grpc_gcp_handshaker_req_equals(req, decoded_req)); + grpc_gcp_handshaker_req_destroy(req); + grpc_gcp_handshaker_req_destroy(decoded_req); + grpc_slice_unref(encoded_req); + + /* handshaker_req_client_start. */ + req = grpc_gcp_handshaker_req_create(CLIENT_START_REQ); + decoded_req = grpc_gcp_handshaker_decoded_req_create(CLIENT_START_REQ); + GPR_ASSERT(grpc_gcp_handshaker_req_set_handshake_protocol( + req, grpc_gcp_HandshakeProtocol_TLS)); + GPR_ASSERT(grpc_gcp_handshaker_req_set_local_identity_hostname( + req, "www.google.com")); + GPR_ASSERT(grpc_gcp_handshaker_req_set_local_endpoint( + req, "2001:db8::8:800:200C:417a", 9876, grpc_gcp_NetworkProtocol_TCP)); + GPR_ASSERT(grpc_gcp_handshaker_req_set_remote_endpoint( + req, "2001:db8::bac5::fed0:84a2", 1234, grpc_gcp_NetworkProtocol_TCP)); + GPR_ASSERT(grpc_gcp_handshaker_req_add_application_protocol(req, "grpc")); + GPR_ASSERT(grpc_gcp_handshaker_req_add_application_protocol(req, "http2")); + GPR_ASSERT( + grpc_gcp_handshaker_req_add_record_protocol(req, "ALTSRP_GCM_AES256")); + GPR_ASSERT( + grpc_gcp_handshaker_req_add_record_protocol(req, "ALTSRP_GCM_AES384")); + GPR_ASSERT(grpc_gcp_handshaker_req_add_target_identity_service_account( + req, "foo@google.com")); + GPR_ASSERT(grpc_gcp_handshaker_req_set_target_name( + req, "google.example.library.service")); + GPR_ASSERT(grpc_gcp_handshaker_req_set_rpc_versions( + req, max_rpc_version_major, max_rpc_version_minor, min_rpc_version_major, + min_rpc_version_minor)); + GPR_ASSERT(grpc_gcp_handshaker_req_encode(req, &encoded_req)); + GPR_ASSERT(grpc_gcp_handshaker_req_decode(encoded_req, decoded_req)); + GPR_ASSERT(grpc_gcp_handshaker_req_equals(req, decoded_req)); + grpc_gcp_handshaker_req_destroy(req); + grpc_gcp_handshaker_req_destroy(decoded_req); + grpc_slice_unref(encoded_req); + + /* handshaker_req_server_start. */ + req = grpc_gcp_handshaker_req_create(SERVER_START_REQ); + decoded_req = grpc_gcp_handshaker_decoded_req_create(SERVER_START_REQ); + GPR_ASSERT(grpc_gcp_handshaker_req_add_application_protocol(req, "grpc")); + GPR_ASSERT(grpc_gcp_handshaker_req_add_application_protocol(req, "http2")); + GPR_ASSERT(grpc_gcp_handshaker_req_set_local_endpoint( + req, "2001:db8::8:800:200C:417a", 9876, grpc_gcp_NetworkProtocol_TCP)); + GPR_ASSERT(grpc_gcp_handshaker_req_set_remote_endpoint( + req, "2001:db8::bac5::fed0:84a2", 1234, grpc_gcp_NetworkProtocol_UDP)); + GPR_ASSERT( + grpc_gcp_handshaker_req_set_in_bytes(req, in_bytes, strlen(in_bytes))); + GPR_ASSERT(grpc_gcp_handshaker_req_param_add_record_protocol( + req, grpc_gcp_HandshakeProtocol_TLS, "ALTSRP_GCM_AES128")); + GPR_ASSERT(grpc_gcp_handshaker_req_param_add_local_identity_service_account( + req, grpc_gcp_HandshakeProtocol_TLS, "foo@google.com")); + GPR_ASSERT(grpc_gcp_handshaker_req_param_add_local_identity_hostname( + req, grpc_gcp_HandshakeProtocol_TLS, "yihuaz0.mtv.corp.google.com")); + GPR_ASSERT(grpc_gcp_handshaker_req_param_add_record_protocol( + req, grpc_gcp_HandshakeProtocol_ALTS, "ALTSRP_GCM_AES128")); + GPR_ASSERT(grpc_gcp_handshaker_req_param_add_local_identity_hostname( + req, grpc_gcp_HandshakeProtocol_ALTS, "www.amazon.com")); + GPR_ASSERT(grpc_gcp_handshaker_req_set_rpc_versions( + req, max_rpc_version_major, max_rpc_version_minor, min_rpc_version_major, + min_rpc_version_minor)); + + GPR_ASSERT(grpc_gcp_handshaker_req_encode(req, &encoded_req)); + GPR_ASSERT(grpc_gcp_handshaker_req_decode(encoded_req, decoded_req)); + GPR_ASSERT(grpc_gcp_handshaker_req_equals(req, decoded_req)); + grpc_gcp_handshaker_req_destroy(req); + grpc_gcp_handshaker_req_destroy(decoded_req); + grpc_slice_unref(encoded_req); + + /* handshaker_resp. */ + grpc_gcp_handshaker_resp* resp = grpc_gcp_handshaker_resp_create(); + grpc_gcp_handshaker_resp* decoded_resp = grpc_gcp_handshaker_resp_create(); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_out_frames(resp, out_frames, + strlen(out_frames))); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_bytes_consumed(resp, 1024)); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_application_protocol(resp, "http")); + GPR_ASSERT( + grpc_gcp_handshaker_resp_set_record_protocol(resp, "ALTSRP_GCM_AES128")); + GPR_ASSERT( + grpc_gcp_handshaker_resp_set_key_data(resp, key_data, strlen(key_data))); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_local_identity_hostname( + resp, "www.faceboook.com")); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_peer_identity_hostname( + resp, "www.amazon.com")); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_channel_open( + resp, false /* channel_open */)); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_code(resp, 1023)); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_details(resp, details)); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_peer_rpc_versions( + resp, max_rpc_version_major, max_rpc_version_minor, min_rpc_version_major, + min_rpc_version_minor)); + grpc_slice encoded_resp; + GPR_ASSERT(grpc_gcp_handshaker_resp_encode(resp, &encoded_resp)); + GPR_ASSERT(grpc_gcp_handshaker_resp_decode(encoded_resp, decoded_resp)); + GPR_ASSERT(grpc_gcp_handshaker_resp_equals(resp, decoded_resp)); + grpc_gcp_handshaker_resp_destroy(resp); + grpc_gcp_handshaker_resp_destroy(decoded_resp); + grpc_slice_unref(encoded_resp); + /* Test invalid arguments. */ + GPR_ASSERT(!grpc_gcp_handshaker_req_set_in_bytes(nullptr, in_bytes, + strlen(in_bytes))); + GPR_ASSERT(!grpc_gcp_handshaker_req_param_add_record_protocol( + req, grpc_gcp_HandshakeProtocol_TLS, nullptr)); + GPR_ASSERT(!grpc_gcp_handshaker_req_param_add_local_identity_service_account( + nullptr, grpc_gcp_HandshakeProtocol_TLS, nullptr)); + GPR_ASSERT(!grpc_gcp_handshaker_resp_set_record_protocol(nullptr, nullptr)); +} diff --git a/test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.cc b/test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.cc new file mode 100644 index 0000000000..ecca04defa --- /dev/null +++ b/test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.cc @@ -0,0 +1,642 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.h" + +const size_t kHandshakeProtocolNum = 3; + +grpc_gcp_handshaker_req* grpc_gcp_handshaker_decoded_req_create( + grpc_gcp_handshaker_req_type type) { + grpc_gcp_handshaker_req* req = + static_cast<grpc_gcp_handshaker_req*>(gpr_zalloc(sizeof(*req))); + switch (type) { + case CLIENT_START_REQ: + req->has_client_start = true; + req->client_start.target_identities.funcs.decode = + decode_repeated_identity_cb; + req->client_start.application_protocols.funcs.decode = + decode_repeated_string_cb; + req->client_start.record_protocols.funcs.decode = + decode_repeated_string_cb; + req->client_start.local_identity.hostname.funcs.decode = + decode_string_or_bytes_cb; + req->client_start.local_identity.service_account.funcs.decode = + decode_string_or_bytes_cb; + req->client_start.local_endpoint.ip_address.funcs.decode = + decode_string_or_bytes_cb; + req->client_start.remote_endpoint.ip_address.funcs.decode = + decode_string_or_bytes_cb; + req->client_start.target_name.funcs.decode = decode_string_or_bytes_cb; + break; + case SERVER_START_REQ: + req->has_server_start = true; + req->server_start.application_protocols.funcs.decode = + &decode_repeated_string_cb; + for (size_t i = 0; i < kHandshakeProtocolNum; i++) { + req->server_start.handshake_parameters[i] + .value.local_identities.funcs.decode = &decode_repeated_identity_cb; + req->server_start.handshake_parameters[i] + .value.record_protocols.funcs.decode = &decode_repeated_string_cb; + } + req->server_start.in_bytes.funcs.decode = decode_string_or_bytes_cb; + req->server_start.local_endpoint.ip_address.funcs.decode = + decode_string_or_bytes_cb; + req->server_start.remote_endpoint.ip_address.funcs.decode = + decode_string_or_bytes_cb; + break; + case NEXT_REQ: + req->has_next = true; + break; + } + return req; +} + +bool grpc_gcp_handshaker_resp_set_application_protocol( + grpc_gcp_handshaker_resp* resp, const char* application_protocol) { + if (resp == nullptr || application_protocol == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to " + "handshaker_resp_set_application_protocol()."); + return false; + } + resp->has_result = true; + grpc_slice* slice = + create_slice(application_protocol, strlen(application_protocol)); + resp->result.application_protocol.arg = slice; + resp->result.application_protocol.funcs.encode = encode_string_or_bytes_cb; + return true; +} + +bool grpc_gcp_handshaker_resp_set_record_protocol( + grpc_gcp_handshaker_resp* resp, const char* record_protocol) { + if (resp == nullptr || record_protocol == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to " + "handshaker_resp_set_record_protocol()."); + return false; + } + resp->has_result = true; + grpc_slice* slice = create_slice(record_protocol, strlen(record_protocol)); + resp->result.record_protocol.arg = slice; + resp->result.record_protocol.funcs.encode = encode_string_or_bytes_cb; + return true; +} + +bool grpc_gcp_handshaker_resp_set_key_data(grpc_gcp_handshaker_resp* resp, + const char* key_data, size_t size) { + if (resp == nullptr || key_data == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to handshaker_resp_set_key_data()."); + return false; + } + resp->has_result = true; + grpc_slice* slice = create_slice(key_data, size); + resp->result.key_data.arg = slice; + resp->result.key_data.funcs.encode = encode_string_or_bytes_cb; + return true; +} + +static void set_identity_hostname(grpc_gcp_identity* identity, + const char* hostname) { + grpc_slice* slice = create_slice(hostname, strlen(hostname)); + identity->hostname.arg = slice; + identity->hostname.funcs.encode = encode_string_or_bytes_cb; +} + +static void set_identity_service_account(grpc_gcp_identity* identity, + const char* service_account) { + grpc_slice* slice = create_slice(service_account, strlen(service_account)); + identity->service_account.arg = slice; + identity->service_account.funcs.encode = encode_string_or_bytes_cb; +} + +bool grpc_gcp_handshaker_resp_set_local_identity_hostname( + grpc_gcp_handshaker_resp* resp, const char* hostname) { + if (resp == nullptr || hostname == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to " + "grpc_gcp_handshaker_resp_set_local_identity_hostname()."); + return false; + } + resp->has_result = true; + resp->result.has_local_identity = true; + set_identity_hostname(&resp->result.local_identity, hostname); + return true; +} + +bool grpc_gcp_handshaker_resp_set_local_identity_service_account( + grpc_gcp_handshaker_resp* resp, const char* service_account) { + if (resp == nullptr || service_account == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to " + "grpc_gcp_handshaker_resp_set_local_identity_service_account()."); + return false; + } + resp->has_result = true; + resp->result.has_local_identity = true; + set_identity_service_account(&resp->result.local_identity, service_account); + return true; +} + +bool grpc_gcp_handshaker_resp_set_peer_identity_hostname( + grpc_gcp_handshaker_resp* resp, const char* hostname) { + if (resp == nullptr || hostname == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to " + "grpc_gcp_handshaker_resp_set_peer_identity_hostname()."); + return false; + } + resp->has_result = true; + resp->result.has_peer_identity = true; + set_identity_hostname(&resp->result.peer_identity, hostname); + return true; +} + +bool grpc_gcp_handshaker_resp_set_peer_identity_service_account( + grpc_gcp_handshaker_resp* resp, const char* service_account) { + if (resp == nullptr || service_account == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to " + "grpc_gcp_handshaker_resp_set_peer_identity_service_account()."); + return false; + } + resp->has_result = true; + resp->result.has_peer_identity = true; + set_identity_service_account(&resp->result.peer_identity, service_account); + return true; +} + +bool grpc_gcp_handshaker_resp_set_channel_open(grpc_gcp_handshaker_resp* resp, + bool keep_channel_open) { + if (resp == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr argument to " + "grpc_gcp_handshaker_resp_set_channel_open()."); + return false; + } + resp->has_result = true; + resp->result.has_keep_channel_open = true; + resp->result.keep_channel_open = keep_channel_open; + return true; +} + +bool grpc_gcp_handshaker_resp_set_code(grpc_gcp_handshaker_resp* resp, + uint32_t code) { + if (resp == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr argument to grpc_gcp_handshaker_resp_set_code()."); + return false; + } + resp->has_status = true; + resp->status.has_code = true; + resp->status.code = code; + return true; +} + +bool grpc_gcp_handshaker_resp_set_details(grpc_gcp_handshaker_resp* resp, + const char* details) { + if (resp == nullptr || details == nullptr) { + gpr_log( + GPR_ERROR, + "Invalid nullptr arguments to grpc_gcp_handshaker_resp_set_details()."); + return false; + } + resp->has_status = true; + grpc_slice* slice = create_slice(details, strlen(details)); + resp->status.details.arg = slice; + resp->status.details.funcs.encode = encode_string_or_bytes_cb; + return true; +} + +bool grpc_gcp_handshaker_resp_set_out_frames(grpc_gcp_handshaker_resp* resp, + const char* out_frames, + size_t size) { + if (resp == nullptr || out_frames == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to " + "grpc_gcp_handshaker_resp_set_out_frames()."); + return false; + } + grpc_slice* slice = create_slice(out_frames, size); + resp->out_frames.arg = slice; + resp->out_frames.funcs.encode = encode_string_or_bytes_cb; + return true; +} + +bool grpc_gcp_handshaker_resp_set_bytes_consumed(grpc_gcp_handshaker_resp* resp, + int32_t bytes_consumed) { + if (resp == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr argument to " + "grpc_gcp_handshaker_resp_set_bytes_consumed()."); + return false; + } + resp->has_bytes_consumed = true; + resp->bytes_consumed = bytes_consumed; + return true; +} + +bool grpc_gcp_handshaker_resp_set_peer_rpc_versions( + grpc_gcp_handshaker_resp* resp, uint32_t max_major, uint32_t max_minor, + uint32_t min_major, uint32_t min_minor) { + if (resp == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr argument to " + "grpc_gcp_handshaker_resp_set_peer_rpc_versions()."); + return false; + } + resp->has_result = true; + resp->result.has_peer_rpc_versions = true; + grpc_gcp_rpc_protocol_versions* versions = &resp->result.peer_rpc_versions; + versions->has_max_rpc_version = true; + versions->has_min_rpc_version = true; + versions->max_rpc_version.has_major = true; + versions->max_rpc_version.has_minor = true; + versions->min_rpc_version.has_major = true; + versions->min_rpc_version.has_minor = true; + versions->max_rpc_version.major = max_major; + versions->max_rpc_version.minor = max_minor; + versions->min_rpc_version.major = min_major; + versions->min_rpc_version.minor = min_minor; + return true; +} + +bool grpc_gcp_handshaker_resp_encode(grpc_gcp_handshaker_resp* resp, + grpc_slice* slice) { + if (resp == nullptr || slice == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr arguments to grpc_gcp_handshaker_resp_encode()."); + return false; + } + pb_ostream_t size_stream; + memset(&size_stream, 0, sizeof(pb_ostream_t)); + if (!pb_encode(&size_stream, grpc_gcp_HandshakerResp_fields, resp)) { + gpr_log(GPR_ERROR, "nanopb error: %s", PB_GET_ERROR(&size_stream)); + return false; + } + size_t encoded_length = size_stream.bytes_written; + *slice = grpc_slice_malloc(encoded_length); + pb_ostream_t output_stream = + pb_ostream_from_buffer(GRPC_SLICE_START_PTR(*slice), encoded_length); + if (!pb_encode(&output_stream, grpc_gcp_HandshakerResp_fields, resp)) { + gpr_log(GPR_ERROR, "nanopb error: %s", PB_GET_ERROR(&size_stream)); + return false; + } + return true; +} + +bool grpc_gcp_handshaker_req_decode(grpc_slice slice, + grpc_gcp_handshaker_req* req) { + if (req == nullptr) { + gpr_log(GPR_ERROR, + "Invalid nullptr argument to grpc_gcp_handshaker_req_decode()."); + return false; + } + pb_istream_t stream = pb_istream_from_buffer(GRPC_SLICE_START_PTR(slice), + GRPC_SLICE_LENGTH(slice)); + req->next.in_bytes.funcs.decode = decode_string_or_bytes_cb; + if (!pb_decode(&stream, grpc_gcp_HandshakerReq_fields, req)) { + gpr_log(GPR_ERROR, "nanopb error: %s", PB_GET_ERROR(&stream)); + return false; + } + return true; +} + +/* Check equality of a pair of grpc_slice fields. */ +static bool slice_equals(grpc_slice* l_slice, grpc_slice* r_slice) { + if (l_slice == nullptr && r_slice == nullptr) { + return true; + } + if (l_slice != nullptr && r_slice != nullptr) { + return grpc_slice_eq(*l_slice, *r_slice); + } + return false; +} + +/* Check equality of a pair of grpc_gcp_identity fields. */ +static bool handshaker_identity_equals(const grpc_gcp_identity* l_id, + const grpc_gcp_identity* r_id) { + if (!((l_id->hostname.arg != nullptr) != (r_id->hostname.arg != nullptr))) { + if (l_id->hostname.arg != nullptr) { + return slice_equals(static_cast<grpc_slice*>(l_id->hostname.arg), + static_cast<grpc_slice*>(r_id->hostname.arg)); + } + } else { + return false; + } + if (!((l_id->service_account.arg != nullptr) != + (r_id->service_account.arg != nullptr))) { + if (l_id->service_account.arg != nullptr) { + return slice_equals(static_cast<grpc_slice*>(l_id->service_account.arg), + static_cast<grpc_slice*>(r_id->service_account.arg)); + } + } else { + return false; + } + return true; +} + +static bool handshaker_rpc_versions_equals( + const grpc_gcp_rpc_protocol_versions* l_version, + const grpc_gcp_rpc_protocol_versions* r_version) { + bool result = true; + result &= + (l_version->max_rpc_version.major == r_version->max_rpc_version.major); + result &= + (l_version->max_rpc_version.minor == r_version->max_rpc_version.minor); + result &= + (l_version->min_rpc_version.major == r_version->min_rpc_version.major); + result &= + (l_version->min_rpc_version.minor == r_version->min_rpc_version.minor); + return result; +} + +/* Check equality of a pair of grpc_gcp_endpoint fields. */ +static bool handshaker_endpoint_equals(const grpc_gcp_endpoint* l_end, + const grpc_gcp_endpoint* r_end) { + bool result = true; + result &= (l_end->port == r_end->port); + result &= (l_end->protocol == r_end->protocol); + if (!((l_end->ip_address.arg != nullptr) != + (r_end->ip_address.arg != nullptr))) { + if (l_end->ip_address.arg != nullptr) { + result &= slice_equals(static_cast<grpc_slice*>(l_end->ip_address.arg), + static_cast<grpc_slice*>(r_end->ip_address.arg)); + } + } else { + return false; + } + return result; +} +/** + * Check if a specific repeated field (i.e., target) is contained in a repeated + * field list (i.e., head). + */ +static bool repeated_field_list_contains_identity( + const repeated_field* head, const repeated_field* target) { + repeated_field* field = const_cast<repeated_field*>(head); + while (field != nullptr) { + if (handshaker_identity_equals( + static_cast<const grpc_gcp_identity*>(field->data), + static_cast<const grpc_gcp_identity*>(target->data))) { + return true; + } + field = field->next; + } + return false; +} + +static bool repeated_field_list_contains_string(const repeated_field* head, + const repeated_field* target) { + repeated_field* field = const_cast<repeated_field*>(head); + while (field != nullptr) { + if (slice_equals((grpc_slice*)field->data, (grpc_slice*)target->data)) { + return true; + } + field = field->next; + } + return false; +} + +/* Return a length of repeated field list. */ +static size_t repeated_field_list_get_length(const repeated_field* head) { + repeated_field* field = const_cast<repeated_field*>(head); + size_t len = 0; + while (field != nullptr) { + len++; + field = field->next; + } + return len; +} + +/** + * Check if a pair of repeated field lists contain the same set of repeated + * fields. + */ +static bool repeated_field_list_equals_identity(const repeated_field* l_head, + const repeated_field* r_head) { + if (repeated_field_list_get_length(l_head) != + repeated_field_list_get_length(r_head)) { + return false; + } + repeated_field* field = const_cast<repeated_field*>(l_head); + repeated_field* head = const_cast<repeated_field*>(r_head); + while (field != nullptr) { + if (!repeated_field_list_contains_identity(head, field)) { + return false; + } + field = field->next; + } + return true; +} + +static bool repeated_field_list_equals_string(const repeated_field* l_head, + const repeated_field* r_head) { + if (repeated_field_list_get_length(l_head) != + repeated_field_list_get_length(r_head)) { + return false; + } + repeated_field* field = const_cast<repeated_field*>(l_head); + repeated_field* head = const_cast<repeated_field*>(r_head); + while (field != nullptr) { + if (!repeated_field_list_contains_string(head, field)) { + return false; + } + field = field->next; + } + return true; +} + +/* Check equality of a pair of ALTS client_start handshake requests. */ +bool grpc_gcp_handshaker_client_start_req_equals( + grpc_gcp_start_client_handshake_req* l_req, + grpc_gcp_start_client_handshake_req* r_req) { + bool result = true; + /* Compare handshake_security_protocol. */ + result &= + l_req->handshake_security_protocol == r_req->handshake_security_protocol; + /* Compare application_protocols, record_protocols, and target_identities. */ + result &= repeated_field_list_equals_string( + static_cast<const repeated_field*>(l_req->application_protocols.arg), + static_cast<const repeated_field*>(r_req->application_protocols.arg)); + result &= repeated_field_list_equals_string( + static_cast<const repeated_field*>(l_req->record_protocols.arg), + static_cast<const repeated_field*>(r_req->record_protocols.arg)); + result &= repeated_field_list_equals_identity( + static_cast<const repeated_field*>(l_req->target_identities.arg), + static_cast<const repeated_field*>(r_req->target_identities.arg)); + if ((l_req->has_local_identity ^ r_req->has_local_identity) | + (l_req->has_local_endpoint ^ r_req->has_local_endpoint) | + ((l_req->has_remote_endpoint ^ r_req->has_remote_endpoint)) | + (l_req->has_rpc_versions ^ r_req->has_rpc_versions)) { + return false; + } + /* Compare local_identity, local_endpoint, and remote_endpoint. */ + if (l_req->has_local_identity) { + result &= handshaker_identity_equals(&l_req->local_identity, + &r_req->local_identity); + } + if (l_req->has_local_endpoint) { + result &= handshaker_endpoint_equals(&l_req->local_endpoint, + &r_req->local_endpoint); + } + if (l_req->has_remote_endpoint) { + result &= handshaker_endpoint_equals(&l_req->remote_endpoint, + &r_req->remote_endpoint); + } + if (l_req->has_rpc_versions) { + result &= handshaker_rpc_versions_equals(&l_req->rpc_versions, + &r_req->rpc_versions); + } + return result; +} + +/* Check equality of a pair of ALTS server_start handshake requests. */ +bool grpc_gcp_handshaker_server_start_req_equals( + grpc_gcp_start_server_handshake_req* l_req, + grpc_gcp_start_server_handshake_req* r_req) { + bool result = true; + /* Compare application_protocols. */ + result &= repeated_field_list_equals_string( + static_cast<const repeated_field*>(l_req->application_protocols.arg), + static_cast<const repeated_field*>(r_req->application_protocols.arg)); + /* Compare handshake_parameters. */ + size_t i = 0, j = 0; + result &= + (l_req->handshake_parameters_count == r_req->handshake_parameters_count); + for (i = 0; i < l_req->handshake_parameters_count; i++) { + bool found = false; + for (j = 0; j < r_req->handshake_parameters_count; j++) { + if (l_req->handshake_parameters[i].key == + r_req->handshake_parameters[j].key) { + found = true; + result &= repeated_field_list_equals_string( + static_cast<const repeated_field*>( + l_req->handshake_parameters[i].value.record_protocols.arg), + static_cast<const repeated_field*>( + r_req->handshake_parameters[j].value.record_protocols.arg)); + result &= repeated_field_list_equals_identity( + static_cast<const repeated_field*>( + l_req->handshake_parameters[i].value.local_identities.arg), + static_cast<const repeated_field*>( + r_req->handshake_parameters[j].value.local_identities.arg)); + } + } + if (!found) { + return false; + } + } + /* Compare in_bytes, local_endpoint, remote_endpoint. */ + result &= slice_equals(static_cast<grpc_slice*>(l_req->in_bytes.arg), + static_cast<grpc_slice*>(r_req->in_bytes.arg)); + if ((l_req->has_local_endpoint ^ r_req->has_local_endpoint) | + (l_req->has_remote_endpoint ^ r_req->has_remote_endpoint) | + (l_req->has_rpc_versions ^ r_req->has_rpc_versions)) + return false; + if (l_req->has_local_endpoint) { + result &= handshaker_endpoint_equals(&l_req->local_endpoint, + &r_req->local_endpoint); + } + if (l_req->has_remote_endpoint) { + result &= handshaker_endpoint_equals(&l_req->remote_endpoint, + &r_req->remote_endpoint); + } + if (l_req->has_rpc_versions) { + result &= handshaker_rpc_versions_equals(&l_req->rpc_versions, + &r_req->rpc_versions); + } + return result; +} + +/* Check equality of a pair of ALTS handshake requests. */ +bool grpc_gcp_handshaker_req_equals(grpc_gcp_handshaker_req* l_req, + grpc_gcp_handshaker_req* r_req) { + if (l_req->has_next && r_req->has_next) { + return slice_equals(static_cast<grpc_slice*>(l_req->next.in_bytes.arg), + static_cast<grpc_slice*>(r_req->next.in_bytes.arg)); + } else if (l_req->has_client_start && r_req->has_client_start) { + return grpc_gcp_handshaker_client_start_req_equals(&l_req->client_start, + &r_req->client_start); + } else if (l_req->has_server_start && r_req->has_server_start) { + return grpc_gcp_handshaker_server_start_req_equals(&l_req->server_start, + &r_req->server_start); + } + return false; +} + +/* Check equality of a pair of ALTS handshake results. */ +bool grpc_gcp_handshaker_resp_result_equals( + grpc_gcp_handshaker_result* l_result, + grpc_gcp_handshaker_result* r_result) { + bool result = true; + /* Compare application_protocol, record_protocol, and key_data. */ + result &= slice_equals( + static_cast<grpc_slice*>(l_result->application_protocol.arg), + static_cast<grpc_slice*>(r_result->application_protocol.arg)); + result &= + slice_equals(static_cast<grpc_slice*>(l_result->record_protocol.arg), + static_cast<grpc_slice*>(r_result->record_protocol.arg)); + result &= slice_equals(static_cast<grpc_slice*>(l_result->key_data.arg), + static_cast<grpc_slice*>(r_result->key_data.arg)); + /* Compare local_identity, peer_identity, and keep_channel_open. */ + if ((l_result->has_local_identity ^ r_result->has_local_identity) | + (l_result->has_peer_identity ^ r_result->has_peer_identity) | + (l_result->has_peer_rpc_versions ^ r_result->has_peer_rpc_versions)) { + return false; + } + if (l_result->has_local_identity) { + result &= handshaker_identity_equals(&l_result->local_identity, + &r_result->local_identity); + } + if (l_result->has_peer_identity) { + result &= handshaker_identity_equals(&l_result->peer_identity, + &r_result->peer_identity); + } + if (l_result->has_peer_rpc_versions) { + result &= handshaker_rpc_versions_equals(&l_result->peer_rpc_versions, + &r_result->peer_rpc_versions); + } + result &= (l_result->keep_channel_open == r_result->keep_channel_open); + return result; +} + +/* Check equality of a pair of ALTS handshake responses. */ +bool grpc_gcp_handshaker_resp_equals(grpc_gcp_handshaker_resp* l_resp, + grpc_gcp_handshaker_resp* r_resp) { + bool result = true; + /* Compare out_frames and bytes_consumed. */ + result &= slice_equals(static_cast<grpc_slice*>(l_resp->out_frames.arg), + static_cast<grpc_slice*>(r_resp->out_frames.arg)); + result &= (l_resp->bytes_consumed == r_resp->bytes_consumed); + /* Compare result and status. */ + if ((l_resp->has_result ^ r_resp->has_result) | + (l_resp->has_status ^ r_resp->has_status)) { + return false; + } + if (l_resp->has_result) { + result &= grpc_gcp_handshaker_resp_result_equals(&l_resp->result, + &r_resp->result); + } + if (l_resp->has_status) { + result &= (l_resp->status.code == r_resp->status.code); + result &= + slice_equals(static_cast<grpc_slice*>(l_resp->status.details.arg), + static_cast<grpc_slice*>(r_resp->status.details.arg)); + } + return result; +} diff --git a/test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.h b/test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.h new file mode 100644 index 0000000000..2fcbb4ea99 --- /dev/null +++ b/test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.h @@ -0,0 +1,143 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_TEST_CORE_TSI_ALTS_HANDSHAKER_ALTS_HANDSHAKER_SERVICE_API_TEST_LIB_H +#define GRPC_TEST_CORE_TSI_ALTS_HANDSHAKER_ALTS_HANDSHAKER_SERVICE_API_TEST_LIB_H + +#include "src/core/tsi/alts/handshaker/alts_handshaker_service_api.h" +#include "src/core/tsi/alts/handshaker/alts_handshaker_service_api_util.h" +#include "src/core/tsi/alts/handshaker/transport_security_common_api.h" + +/** + * The first part of this file contains function signatures for de-serializing + * ALTS handshake requests and setting/serializing ALTS handshake responses, + * which simulate the behaviour of grpc server that runs ALTS handshaker + * service. + */ + +/** + * This method creates a ALTS handshaker request that is used to hold + * de-serialized result. + */ +grpc_gcp_handshaker_req* grpc_gcp_handshaker_decoded_req_create( + grpc_gcp_handshaker_req_type type); + +/* This method de-serializes a ALTS handshaker request. */ +bool grpc_gcp_handshaker_req_decode(grpc_slice slice, + grpc_gcp_handshaker_req* req); + +/* This method serializes a ALTS handshaker response. */ +bool grpc_gcp_handshaker_resp_encode(grpc_gcp_handshaker_resp* resp, + grpc_slice* slice); + +/* This method sets application protocol of ALTS handshaker response. */ +bool grpc_gcp_handshaker_resp_set_application_protocol( + grpc_gcp_handshaker_resp* resp, const char* application_protocol); + +/* This method sets record protocol of ALTS handshaker response. */ +bool grpc_gcp_handshaker_resp_set_record_protocol( + grpc_gcp_handshaker_resp* resp, const char* record_protocol); + +/* This method sets key_data of ALTS handshaker response. */ +bool grpc_gcp_handshaker_resp_set_key_data(grpc_gcp_handshaker_resp* resp, + const char* key_data, size_t size); + +/* This method sets local identity's hostname for ALTS handshaker response. */ +bool grpc_gcp_handshaker_resp_set_local_identity_hostname( + grpc_gcp_handshaker_resp* resp, const char* hostname); + +/** + * This method sets local identity's service account for ALTS handshaker + * response. + */ +bool grpc_gcp_handshaker_resp_set_local_identity_service_account( + grpc_gcp_handshaker_resp* resp, const char* service_account); + +/* This method sets peer identity's hostname for ALTS handshaker response. */ +bool grpc_gcp_handshaker_resp_set_peer_identity_hostname( + grpc_gcp_handshaker_resp* resp, const char* hostname); + +/** + * This method sets peer identity's service account for ALTS handshaker + * response. + */ +bool grpc_gcp_handshaker_resp_set_peer_identity_service_account( + grpc_gcp_handshaker_resp* resp, const char* service_account); + +/* This method sets keep_channel_open for ALTS handshaker response. */ +bool grpc_gcp_handshaker_resp_set_channel_open(grpc_gcp_handshaker_resp* resp, + bool keep_channel_open); + +/* This method sets code for ALTS handshaker response. */ +bool grpc_gcp_handshaker_resp_set_code(grpc_gcp_handshaker_resp* resp, + uint32_t code); + +/* This method sets details for ALTS handshaker response. */ +bool grpc_gcp_handshaker_resp_set_details(grpc_gcp_handshaker_resp* resp, + const char* details); + +/* This method sets out_frames for ALTS handshaker response. */ +bool grpc_gcp_handshaker_resp_set_out_frames(grpc_gcp_handshaker_resp* resp, + const char* out_frames, + size_t size); + +/* This method sets peer_rpc_versions for ALTS handshaker response. */ +bool grpc_gcp_handshaker_resp_set_peer_rpc_versions( + grpc_gcp_handshaker_resp* resp, uint32_t max_major, uint32_t max_minor, + uint32_t min_major, uint32_t min_minor); + +/* This method sets bytes_consumed for ALTS handshaker response. */ +bool grpc_gcp_handshaker_resp_set_bytes_consumed(grpc_gcp_handshaker_resp* resp, + int32_t bytes_consumed); + +/* This method serializes ALTS handshaker response. */ +bool grpc_gcp_handshaker_resp_encode(grpc_gcp_handshaker_resp* resp, + grpc_slice* slice); + +/* This method de-serializes ALTS handshaker request. */ +bool grpc_gcp_handshaker_req_decode(grpc_slice slice, + grpc_gcp_handshaker_req* req); + +/** + * The second part contains function signatures for checking equality of a pair + * of ALTS handshake requests/responses. + */ + +/* This method checks equality of two client_start handshaker requests. */ +bool grpc_gcp_handshaker_client_start_req_equals( + grpc_gcp_start_client_handshake_req* l_req, + grpc_gcp_start_client_handshake_req* r_req); + +/* This method checks equality of two server_start handshaker requests. */ +bool grpc_gcp_handshaker_server_start_req_equals( + grpc_gcp_start_server_handshake_req* l_req, + grpc_gcp_start_server_handshake_req* r_req); + +/* This method checks equality of two ALTS handshaker requests. */ +bool grpc_gcp_handshaker_req_equals(grpc_gcp_handshaker_req* l_req, + grpc_gcp_handshaker_req* r_req); + +/* This method checks equality of two handshaker response results. */ +bool grpc_gcp_handshaker_resp_result_equals( + grpc_gcp_handshaker_result* l_result, grpc_gcp_handshaker_result* r_result); + +/* This method checks equality of two ALTS handshaker responses. */ +bool grpc_gcp_handshaker_resp_equals(grpc_gcp_handshaker_resp* l_resp, + grpc_gcp_handshaker_resp* r_resp); + +#endif // GRPC_TEST_CORE_TSI_ALTS_HANDSHAKER_ALTS_HANDSHAKER_SERVICE_API_TEST_LIB_H diff --git a/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc b/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc new file mode 100644 index 0000000000..95724f84f4 --- /dev/null +++ b/test/core/tsi/alts/handshaker/alts_tsi_handshaker_test.cc @@ -0,0 +1,682 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <stdio.h> +#include <stdlib.h> + +#include <grpc/grpc.h> +#include <grpc/support/sync.h> + +#include "src/core/lib/gprpp/thd.h" +#include "src/core/tsi/alts/handshaker/alts_handshaker_client.h" +#include "src/core/tsi/alts/handshaker/alts_tsi_event.h" +#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker.h" +#include "src/core/tsi/alts/handshaker/alts_tsi_handshaker_private.h" +#include "test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.h" + +#define ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES "Hello World" +#define ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME "Hello Google" +#define ALTS_TSI_HANDSHAKER_TEST_CONSUMED_BYTES "Hello " +#define ALTS_TSI_HANDSHAKER_TEST_REMAIN_BYTES "Google" +#define ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY "chapi@service.google.com" +#define ALTS_TSI_HANDSHAKER_TEST_KEY_DATA \ + "ABCDEFGHIJKLMNOPABCDEFGHIJKLMNOPABCDEFGHIJKL" +#define ALTS_TSI_HANDSHAKER_TEST_BUFFER_SIZE 100 +#define ALTS_TSI_HANDSHAKER_TEST_SLEEP_TIME_IN_SECONDS 2 +#define ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MAJOR 3 +#define ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MINOR 2 +#define ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MAJOR 2 +#define ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MINOR 1 + +using grpc_core::internal:: + alts_tsi_handshaker_get_has_sent_start_message_for_testing; +using grpc_core::internal::alts_tsi_handshaker_get_is_client_for_testing; +using grpc_core::internal::alts_tsi_handshaker_get_recv_bytes_for_testing; +using grpc_core::internal::alts_tsi_handshaker_set_client_for_testing; +using grpc_core::internal::alts_tsi_handshaker_set_recv_bytes_for_testing; + +/* ALTS mock notification. */ +typedef struct notification { + gpr_cv cv; + gpr_mu mu; + bool notified; +} notification; + +/* ALTS mock handshaker client. */ +typedef struct alts_mock_handshaker_client { + alts_handshaker_client base; + bool used_for_success_test; +} alts_mock_handshaker_client; + +/* Type of ALTS handshaker response. */ +typedef enum { + INVALID, + FAILED, + CLIENT_START, + SERVER_START, + CLIENT_NEXT, + SERVER_NEXT, +} alts_handshaker_response_type; + +static alts_tsi_event* client_start_event; +static alts_tsi_event* client_next_event; +static alts_tsi_event* server_start_event; +static alts_tsi_event* server_next_event; +static notification caller_to_tsi_notification; +static notification tsi_to_caller_notification; + +static void notification_init(notification* n) { + gpr_mu_init(&n->mu); + gpr_cv_init(&n->cv); + n->notified = false; +} + +static void notification_destroy(notification* n) { + gpr_mu_destroy(&n->mu); + gpr_cv_destroy(&n->cv); +} + +static void signal(notification* n) { + gpr_mu_lock(&n->mu); + n->notified = true; + gpr_cv_signal(&n->cv); + gpr_mu_unlock(&n->mu); +} + +static void wait(notification* n) { + gpr_mu_lock(&n->mu); + while (!n->notified) { + gpr_cv_wait(&n->cv, &n->mu, gpr_inf_future(GPR_CLOCK_REALTIME)); + } + n->notified = false; + gpr_mu_unlock(&n->mu); +} + +/** + * This method mocks ALTS handshaker service to generate handshaker response + * for a specific request. + */ +static grpc_byte_buffer* generate_handshaker_response( + alts_handshaker_response_type type) { + grpc_gcp_handshaker_resp* resp = grpc_gcp_handshaker_resp_create(); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_code(resp, 0)); + switch (type) { + case INVALID: + break; + case CLIENT_START: + case SERVER_START: + GPR_ASSERT(grpc_gcp_handshaker_resp_set_out_frames( + resp, ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME, + strlen(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME))); + break; + case CLIENT_NEXT: + GPR_ASSERT(grpc_gcp_handshaker_resp_set_out_frames( + resp, ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME, + strlen(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME))); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_peer_identity_service_account( + resp, ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY)); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_bytes_consumed( + resp, strlen(ALTS_TSI_HANDSHAKER_TEST_CONSUMED_BYTES))); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_key_data( + resp, ALTS_TSI_HANDSHAKER_TEST_KEY_DATA, + strlen(ALTS_TSI_HANDSHAKER_TEST_KEY_DATA))); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_peer_rpc_versions( + resp, ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MAJOR, + ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MINOR, + ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MAJOR, + ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MINOR)); + break; + case SERVER_NEXT: + GPR_ASSERT(grpc_gcp_handshaker_resp_set_peer_identity_service_account( + resp, ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY)); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_bytes_consumed( + resp, strlen(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME))); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_key_data( + resp, ALTS_TSI_HANDSHAKER_TEST_KEY_DATA, + strlen(ALTS_TSI_HANDSHAKER_TEST_KEY_DATA))); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_peer_rpc_versions( + resp, ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MAJOR, + ALTS_TSI_HANDSHAKER_TEST_MAX_RPC_VERSION_MINOR, + ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MAJOR, + ALTS_TSI_HANDSHAKER_TEST_MIN_RPC_VERSION_MINOR)); + break; + case FAILED: + GPR_ASSERT( + grpc_gcp_handshaker_resp_set_code(resp, 3 /* INVALID ARGUMENT */)); + break; + } + grpc_slice slice; + GPR_ASSERT(grpc_gcp_handshaker_resp_encode(resp, &slice)); + if (type == INVALID) { + grpc_slice bad_slice = + grpc_slice_split_head(&slice, GRPC_SLICE_LENGTH(slice) - 1); + grpc_slice_unref(slice); + slice = grpc_slice_ref(bad_slice); + grpc_slice_unref(bad_slice); + } + grpc_byte_buffer* buffer = + grpc_raw_byte_buffer_create(&slice, 1 /* number of slices */); + grpc_slice_unref(slice); + grpc_gcp_handshaker_resp_destroy(resp); + return buffer; +} + +static void check_must_not_be_called(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(0); +} + +static void on_client_start_success_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_OK); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send_size == strlen(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME)); + GPR_ASSERT(memcmp(bytes_to_send, ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME, + bytes_to_send_size) == 0); + GPR_ASSERT(result == nullptr); + /* Validate peer identity. */ + tsi_peer peer; + GPR_ASSERT(tsi_handshaker_result_extract_peer(result, &peer) == + TSI_INVALID_ARGUMENT); + /* Validate frame protector. */ + tsi_frame_protector* protector = nullptr; + GPR_ASSERT(tsi_handshaker_result_create_frame_protector( + result, nullptr, &protector) == TSI_INVALID_ARGUMENT); + /* Validate unused bytes. */ + const unsigned char* unused_bytes = nullptr; + size_t unused_bytes_size = 0; + GPR_ASSERT(tsi_handshaker_result_get_unused_bytes(result, &unused_bytes, + &unused_bytes_size) == + TSI_INVALID_ARGUMENT); + signal(&tsi_to_caller_notification); +} + +static void on_server_start_success_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_OK); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send_size == strlen(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME)); + GPR_ASSERT(memcmp(bytes_to_send, ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME, + bytes_to_send_size) == 0); + GPR_ASSERT(result == nullptr); + /* Validate peer identity. */ + tsi_peer peer; + GPR_ASSERT(tsi_handshaker_result_extract_peer(result, &peer) == + TSI_INVALID_ARGUMENT); + /* Validate frame protector. */ + tsi_frame_protector* protector = nullptr; + GPR_ASSERT(tsi_handshaker_result_create_frame_protector( + result, nullptr, &protector) == TSI_INVALID_ARGUMENT); + /* Validate unused bytes. */ + const unsigned char* unused_bytes = nullptr; + size_t unused_bytes_size = 0; + GPR_ASSERT(tsi_handshaker_result_get_unused_bytes(result, &unused_bytes, + &unused_bytes_size) == + TSI_INVALID_ARGUMENT); + signal(&tsi_to_caller_notification); +} + +static void on_client_next_success_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_OK); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send_size == strlen(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME)); + GPR_ASSERT(memcmp(bytes_to_send, ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME, + bytes_to_send_size) == 0); + GPR_ASSERT(result != nullptr); + /* Validate peer identity. */ + tsi_peer peer; + GPR_ASSERT(tsi_handshaker_result_extract_peer(result, &peer) == TSI_OK); + GPR_ASSERT(peer.property_count == kTsiAltsNumOfPeerProperties); + GPR_ASSERT(memcmp(TSI_ALTS_CERTIFICATE_TYPE, peer.properties[0].value.data, + peer.properties[0].value.length) == 0); + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY, + peer.properties[1].value.data, + peer.properties[1].value.length) == 0); + tsi_peer_destruct(&peer); + /* Validate unused bytes. */ + const unsigned char* bytes = nullptr; + size_t bytes_size = 0; + GPR_ASSERT(tsi_handshaker_result_get_unused_bytes(result, &bytes, + &bytes_size) == TSI_OK); + GPR_ASSERT(bytes_size == strlen(ALTS_TSI_HANDSHAKER_TEST_REMAIN_BYTES)); + GPR_ASSERT(memcmp(bytes, ALTS_TSI_HANDSHAKER_TEST_REMAIN_BYTES, bytes_size) == + 0); + /* Validate frame protector. */ + tsi_frame_protector* protector = nullptr; + GPR_ASSERT(tsi_handshaker_result_create_frame_protector( + result, nullptr, &protector) == TSI_OK); + GPR_ASSERT(protector != nullptr); + tsi_frame_protector_destroy(protector); + tsi_handshaker_result_destroy(result); + signal(&tsi_to_caller_notification); +} + +static void on_server_next_success_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_OK); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send_size == 0); + GPR_ASSERT(bytes_to_send == nullptr); + GPR_ASSERT(result != nullptr); + /* Validate peer identity. */ + tsi_peer peer; + GPR_ASSERT(tsi_handshaker_result_extract_peer(result, &peer) == TSI_OK); + GPR_ASSERT(peer.property_count == kTsiAltsNumOfPeerProperties); + GPR_ASSERT(memcmp(TSI_ALTS_CERTIFICATE_TYPE, peer.properties[0].value.data, + peer.properties[0].value.length) == 0); + GPR_ASSERT(memcmp(ALTS_TSI_HANDSHAKER_TEST_PEER_IDENTITY, + peer.properties[1].value.data, + peer.properties[1].value.length) == 0); + tsi_peer_destruct(&peer); + /* Validate unused bytes. */ + const unsigned char* bytes = nullptr; + size_t bytes_size = 0; + GPR_ASSERT(tsi_handshaker_result_get_unused_bytes(result, &bytes, + &bytes_size) == TSI_OK); + GPR_ASSERT(bytes_size == 0); + GPR_ASSERT(bytes == nullptr); + /* Validate frame protector. */ + tsi_frame_protector* protector = nullptr; + GPR_ASSERT(tsi_handshaker_result_create_frame_protector( + result, nullptr, &protector) == TSI_OK); + GPR_ASSERT(protector != nullptr); + tsi_frame_protector_destroy(protector); + tsi_handshaker_result_destroy(result); + signal(&tsi_to_caller_notification); +} + +static tsi_result mock_client_start(alts_handshaker_client* self, + alts_tsi_event* event) { + alts_mock_handshaker_client* client = + reinterpret_cast<alts_mock_handshaker_client*>(self); + if (!client->used_for_success_test) { + alts_tsi_event_destroy(event); + return TSI_INTERNAL_ERROR; + } + GPR_ASSERT(event->cb == on_client_start_success_cb); + GPR_ASSERT(event->user_data == nullptr); + GPR_ASSERT(!alts_tsi_handshaker_get_has_sent_start_message_for_testing( + event->handshaker)); + /* Populate handshaker response for client_start request. */ + event->recv_buffer = generate_handshaker_response(CLIENT_START); + client_start_event = event; + signal(&caller_to_tsi_notification); + return TSI_OK; +} + +static tsi_result mock_server_start(alts_handshaker_client* self, + alts_tsi_event* event, + grpc_slice* bytes_received) { + alts_mock_handshaker_client* client = + reinterpret_cast<alts_mock_handshaker_client*>(self); + if (!client->used_for_success_test) { + alts_tsi_event_destroy(event); + return TSI_INTERNAL_ERROR; + } + GPR_ASSERT(event->cb == on_server_start_success_cb); + GPR_ASSERT(event->user_data == nullptr); + grpc_slice slice = grpc_empty_slice(); + GPR_ASSERT(grpc_slice_cmp(*bytes_received, slice) == 0); + GPR_ASSERT(!alts_tsi_handshaker_get_has_sent_start_message_for_testing( + event->handshaker)); + /* Populate handshaker response for server_start request. */ + event->recv_buffer = generate_handshaker_response(SERVER_START); + server_start_event = event; + grpc_slice_unref(slice); + signal(&caller_to_tsi_notification); + return TSI_OK; +} + +static tsi_result mock_next(alts_handshaker_client* self, alts_tsi_event* event, + grpc_slice* bytes_received) { + alts_mock_handshaker_client* client = + reinterpret_cast<alts_mock_handshaker_client*>(self); + if (!client->used_for_success_test) { + alts_tsi_event_destroy(event); + return TSI_INTERNAL_ERROR; + } + bool is_client = + alts_tsi_handshaker_get_is_client_for_testing(event->handshaker); + if (is_client) { + GPR_ASSERT(event->cb == on_client_next_success_cb); + } else { + GPR_ASSERT(event->cb == on_server_next_success_cb); + } + GPR_ASSERT(event->user_data == nullptr); + GPR_ASSERT(bytes_received != nullptr); + GPR_ASSERT(memcmp(GRPC_SLICE_START_PTR(*bytes_received), + ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES, + GRPC_SLICE_LENGTH(*bytes_received)) == 0); + GPR_ASSERT(grpc_slice_cmp(alts_tsi_handshaker_get_recv_bytes_for_testing( + event->handshaker), + *bytes_received) == 0); + GPR_ASSERT(alts_tsi_handshaker_get_has_sent_start_message_for_testing( + event->handshaker)); + /* Populate handshaker response for next request. */ + grpc_slice out_frame = + grpc_slice_from_static_string(ALTS_TSI_HANDSHAKER_TEST_OUT_FRAME); + if (is_client) { + event->recv_buffer = generate_handshaker_response(CLIENT_NEXT); + } else { + event->recv_buffer = generate_handshaker_response(SERVER_NEXT); + } + alts_tsi_handshaker_set_recv_bytes_for_testing(event->handshaker, &out_frame); + if (is_client) { + client_next_event = event; + } else { + server_next_event = event; + } + signal(&caller_to_tsi_notification); + grpc_slice_unref(out_frame); + return TSI_OK; +} + +static void mock_destruct(alts_handshaker_client* client) {} + +static const alts_handshaker_client_vtable vtable = { + mock_client_start, mock_server_start, mock_next, mock_destruct}; + +static alts_handshaker_client* alts_mock_handshaker_client_create( + bool used_for_success_test) { + alts_mock_handshaker_client* client = + static_cast<alts_mock_handshaker_client*>(gpr_zalloc(sizeof(*client))); + client->base.vtable = &vtable; + client->used_for_success_test = used_for_success_test; + return &client->base; +} + +static tsi_handshaker* create_test_handshaker(bool used_for_success_test, + bool is_client) { + tsi_handshaker* handshaker = nullptr; + alts_handshaker_client* client = + alts_mock_handshaker_client_create(used_for_success_test); + grpc_alts_credentials_options* options = + grpc_alts_credentials_client_options_create(); + alts_tsi_handshaker_create(options, "target_name", "lame", is_client, + &handshaker); + alts_tsi_handshaker* alts_handshaker = + reinterpret_cast<alts_tsi_handshaker*>(handshaker); + alts_tsi_handshaker_set_client_for_testing(alts_handshaker, client); + grpc_alts_credentials_options_destroy(options); + return handshaker; +} + +static void check_handshaker_next_invalid_input() { + /* Initialization. */ + tsi_handshaker* handshaker = create_test_handshaker(true, true); + /* Check nullptr handshaker. */ + GPR_ASSERT(tsi_handshaker_next(nullptr, nullptr, 0, nullptr, nullptr, nullptr, + check_must_not_be_called, + nullptr) == TSI_INVALID_ARGUMENT); + /* Check nullptr callback. */ + GPR_ASSERT(tsi_handshaker_next(handshaker, nullptr, 0, nullptr, nullptr, + nullptr, nullptr, + nullptr) == TSI_INVALID_ARGUMENT); + /* Cleanup. */ + tsi_handshaker_destroy(handshaker); +} + +static void check_handshaker_next_success() { + /** + * Create handshakers for which internal mock client is going to do + * correctness check. + */ + tsi_handshaker* client_handshaker = create_test_handshaker( + true /* used_for_success_test */, true /* is_client */); + tsi_handshaker* server_handshaker = create_test_handshaker( + true /* used_for_success_test */, false /* is_client */); + /* Client start. */ + GPR_ASSERT(tsi_handshaker_next(client_handshaker, nullptr, 0, nullptr, + nullptr, nullptr, on_client_start_success_cb, + nullptr) == TSI_ASYNC); + wait(&tsi_to_caller_notification); + /* Client next. */ + GPR_ASSERT(tsi_handshaker_next( + client_handshaker, + (const unsigned char*)ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES, + strlen(ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES), nullptr, nullptr, + nullptr, on_client_next_success_cb, nullptr) == TSI_ASYNC); + wait(&tsi_to_caller_notification); + /* Server start. */ + GPR_ASSERT(tsi_handshaker_next(server_handshaker, nullptr, 0, nullptr, + nullptr, nullptr, on_server_start_success_cb, + nullptr) == TSI_ASYNC); + wait(&tsi_to_caller_notification); + /* Server next. */ + GPR_ASSERT(tsi_handshaker_next( + server_handshaker, + (const unsigned char*)ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES, + strlen(ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES), nullptr, nullptr, + nullptr, on_server_next_success_cb, nullptr) == TSI_ASYNC); + wait(&tsi_to_caller_notification); + /* Cleanup. */ + tsi_handshaker_destroy(server_handshaker); + tsi_handshaker_destroy(client_handshaker); +} + +static void check_handshaker_next_failure() { + /** + * Create handshakers for which internal mock client is always going to fail. + */ + tsi_handshaker* client_handshaker = create_test_handshaker( + false /* used_for_success_test */, true /* is_client */); + tsi_handshaker* server_handshaker = create_test_handshaker( + false /* used_for_success_test */, false /* is_client */); + /* Client start. */ + GPR_ASSERT(tsi_handshaker_next(client_handshaker, nullptr, 0, nullptr, + nullptr, nullptr, check_must_not_be_called, + nullptr) == TSI_INTERNAL_ERROR); + /* Server start. */ + GPR_ASSERT(tsi_handshaker_next(server_handshaker, nullptr, 0, nullptr, + nullptr, nullptr, check_must_not_be_called, + nullptr) == TSI_INTERNAL_ERROR); + /* Server next. */ + GPR_ASSERT(tsi_handshaker_next( + server_handshaker, + (const unsigned char*)ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES, + strlen(ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES), nullptr, nullptr, + nullptr, check_must_not_be_called, + nullptr) == TSI_INTERNAL_ERROR); + /* Client next. */ + GPR_ASSERT(tsi_handshaker_next( + client_handshaker, + (const unsigned char*)ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES, + strlen(ALTS_TSI_HANDSHAKER_TEST_RECV_BYTES), nullptr, nullptr, + nullptr, check_must_not_be_called, + nullptr) == TSI_INTERNAL_ERROR); + /* Cleanup. */ + tsi_handshaker_destroy(server_handshaker); + tsi_handshaker_destroy(client_handshaker); +} + +static void on_invalid_input_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_INTERNAL_ERROR); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send == nullptr); + GPR_ASSERT(bytes_to_send_size == 0); + GPR_ASSERT(result == nullptr); +} + +static void on_failed_grpc_call_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_INTERNAL_ERROR); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send == nullptr); + GPR_ASSERT(bytes_to_send_size == 0); + GPR_ASSERT(result == nullptr); +} + +static void check_handle_response_invalid_input() { + /** + * Create a handshaker at the client side, for which internal mock client is + * always going to fail. + */ + tsi_handshaker* handshaker = create_test_handshaker( + false /* used_for_success_test */, true /* is_client */); + alts_tsi_handshaker* alts_handshaker = + reinterpret_cast<alts_tsi_handshaker*>(handshaker); + grpc_byte_buffer recv_buffer; + /* Check nullptr handshaker. */ + alts_tsi_handshaker_handle_response(nullptr, &recv_buffer, GRPC_STATUS_OK, + nullptr, on_invalid_input_cb, nullptr, + true); + /* Check nullptr recv_bytes. */ + alts_tsi_handshaker_handle_response(alts_handshaker, nullptr, GRPC_STATUS_OK, + nullptr, on_invalid_input_cb, nullptr, + true); + /* Check failed grpc call made to handshaker service. */ + alts_tsi_handshaker_handle_response(alts_handshaker, &recv_buffer, + GRPC_STATUS_UNKNOWN, nullptr, + on_failed_grpc_call_cb, nullptr, true); + + alts_tsi_handshaker_handle_response(alts_handshaker, &recv_buffer, + GRPC_STATUS_OK, nullptr, + on_failed_grpc_call_cb, nullptr, false); + + /* Cleanup. */ + tsi_handshaker_destroy(handshaker); +} + +static void on_invalid_resp_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_DATA_CORRUPTED); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send == nullptr); + GPR_ASSERT(bytes_to_send_size == 0); + GPR_ASSERT(result == nullptr); +} + +static void check_handle_response_invalid_resp() { + /** + * Create a handshaker at the client side, for which internal mock client is + * always going to fail. + */ + tsi_handshaker* handshaker = create_test_handshaker( + false /* used_for_success_test */, true /* is_client */); + alts_tsi_handshaker* alts_handshaker = + reinterpret_cast<alts_tsi_handshaker*>(handshaker); + /* Tests. */ + grpc_byte_buffer* recv_buffer = generate_handshaker_response(INVALID); + alts_tsi_handshaker_handle_response(alts_handshaker, recv_buffer, + GRPC_STATUS_OK, nullptr, + on_invalid_resp_cb, nullptr, true); + /* Cleanup. */ + grpc_byte_buffer_destroy(recv_buffer); + tsi_handshaker_destroy(handshaker); +} + +static void check_handle_response_success(void* unused) { + /* Client start. */ + wait(&caller_to_tsi_notification); + alts_tsi_event_dispatch_to_handshaker(client_start_event, true /* is_ok */); + alts_tsi_event_destroy(client_start_event); + /* Client next. */ + wait(&caller_to_tsi_notification); + alts_tsi_event_dispatch_to_handshaker(client_next_event, true /* is_ok */); + alts_tsi_event_destroy(client_next_event); + /* Server start. */ + wait(&caller_to_tsi_notification); + alts_tsi_event_dispatch_to_handshaker(server_start_event, true /* is_ok */); + alts_tsi_event_destroy(server_start_event); + /* Server next. */ + wait(&caller_to_tsi_notification); + alts_tsi_event_dispatch_to_handshaker(server_next_event, true /* is_ok */); + alts_tsi_event_destroy(server_next_event); +} + +static void on_failed_resp_cb(tsi_result status, void* user_data, + const unsigned char* bytes_to_send, + size_t bytes_to_send_size, + tsi_handshaker_result* result) { + GPR_ASSERT(status == TSI_INVALID_ARGUMENT); + GPR_ASSERT(user_data == nullptr); + GPR_ASSERT(bytes_to_send == nullptr); + GPR_ASSERT(bytes_to_send_size == 0); + GPR_ASSERT(result == nullptr); +} + +static void check_handle_response_failure() { + /** + * Create a handshaker at the client side, for which internal mock client is + * always going to fail. + */ + tsi_handshaker* handshaker = create_test_handshaker( + false /* used_for_success_test */, true /* is_client */); + alts_tsi_handshaker* alts_handshaker = + reinterpret_cast<alts_tsi_handshaker*>(handshaker); + /* Tests. */ + grpc_byte_buffer* recv_buffer = generate_handshaker_response(FAILED); + alts_tsi_handshaker_handle_response(alts_handshaker, recv_buffer, + GRPC_STATUS_OK, nullptr, + on_failed_resp_cb, nullptr, true); + grpc_byte_buffer_destroy(recv_buffer); + /* Cleanup. */ + tsi_handshaker_destroy(handshaker); +} + +void check_handshaker_success() { + /* Initialization. */ + notification_init(&caller_to_tsi_notification); + notification_init(&tsi_to_caller_notification); + client_start_event = nullptr; + client_next_event = nullptr; + server_start_event = nullptr; + server_next_event = nullptr; + /* Tests. */ + grpc_core::Thread thd("alts_tsi_handshaker_test", + &check_handle_response_success, nullptr); + thd.Start(); + check_handshaker_next_success(); + thd.Join(); + /* Cleanup. */ + notification_destroy(&caller_to_tsi_notification); + notification_destroy(&tsi_to_caller_notification); +} + +int main(int argc, char** argv) { + /* Initialization. */ + grpc_init(); + /* Tests. */ + check_handshaker_success(); + check_handshaker_next_invalid_input(); + check_handshaker_next_failure(); + check_handle_response_invalid_input(); + check_handle_response_invalid_resp(); + check_handle_response_failure(); + /* Cleanup. */ + grpc_shutdown(); + return 0; +} diff --git a/test/core/tsi/alts/handshaker/alts_tsi_utils_test.cc b/test/core/tsi/alts/handshaker/alts_tsi_utils_test.cc new file mode 100644 index 0000000000..98c5d23641 --- /dev/null +++ b/test/core/tsi/alts/handshaker/alts_tsi_utils_test.cc @@ -0,0 +1,73 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/tsi/alts/handshaker/alts_tsi_utils.h" +#include "test/core/tsi/alts/handshaker/alts_handshaker_service_api_test_lib.h" + +#define ALTS_TSI_UTILS_TEST_OUT_FRAME "Hello Google" + +static void convert_to_tsi_result_test() { + GPR_ASSERT(alts_tsi_utils_convert_to_tsi_result(GRPC_STATUS_OK) == TSI_OK); + GPR_ASSERT(alts_tsi_utils_convert_to_tsi_result(GRPC_STATUS_UNKNOWN) == + TSI_UNKNOWN_ERROR); + GPR_ASSERT(alts_tsi_utils_convert_to_tsi_result( + GRPC_STATUS_INVALID_ARGUMENT) == TSI_INVALID_ARGUMENT); + GPR_ASSERT(alts_tsi_utils_convert_to_tsi_result(GRPC_STATUS_OUT_OF_RANGE) == + TSI_UNKNOWN_ERROR); + GPR_ASSERT(alts_tsi_utils_convert_to_tsi_result(GRPC_STATUS_INTERNAL) == + TSI_INTERNAL_ERROR); + GPR_ASSERT(alts_tsi_utils_convert_to_tsi_result(GRPC_STATUS_NOT_FOUND) == + TSI_NOT_FOUND); +} + +static void deserialize_response_test() { + grpc_gcp_handshaker_resp* resp = grpc_gcp_handshaker_resp_create(); + GPR_ASSERT(grpc_gcp_handshaker_resp_set_out_frames( + resp, ALTS_TSI_UTILS_TEST_OUT_FRAME, + strlen(ALTS_TSI_UTILS_TEST_OUT_FRAME))); + grpc_slice slice; + GPR_ASSERT(grpc_gcp_handshaker_resp_encode(resp, &slice)); + + /* Valid serialization. */ + grpc_byte_buffer* buffer = + grpc_raw_byte_buffer_create(&slice, 1 /* number of slices */); + grpc_gcp_handshaker_resp* decoded_resp = + alts_tsi_utils_deserialize_response(buffer); + GPR_ASSERT(grpc_gcp_handshaker_resp_equals(resp, decoded_resp)); + grpc_byte_buffer_destroy(buffer); + + /* Invalid serializaiton. */ + grpc_slice bad_slice = + grpc_slice_split_head(&slice, GRPC_SLICE_LENGTH(slice) - 1); + buffer = grpc_raw_byte_buffer_create(&bad_slice, 1 /* number of slices */); + GPR_ASSERT(alts_tsi_utils_deserialize_response(buffer) == nullptr); + + /* Clean up. */ + grpc_slice_unref(slice); + grpc_slice_unref(bad_slice); + grpc_byte_buffer_destroy(buffer); + grpc_gcp_handshaker_resp_destroy(resp); + grpc_gcp_handshaker_resp_destroy(decoded_resp); +} + +int main(int argc, char** argv) { + /* Tests. */ + deserialize_response_test(); + convert_to_tsi_result_test(); + return 0; +} diff --git a/test/core/tsi/alts/handshaker/transport_security_common_api_test.cc b/test/core/tsi/alts/handshaker/transport_security_common_api_test.cc new file mode 100644 index 0000000000..6ff1357c27 --- /dev/null +++ b/test/core/tsi/alts/handshaker/transport_security_common_api_test.cc @@ -0,0 +1,196 @@ +/* + * + * Copyright 2018 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include <stdbool.h> +#include <stdio.h> +#include <stdlib.h> + +#include "src/core/tsi/alts/handshaker/transport_security_common_api.h" + +const size_t kMaxRpcVersionMajor = 3; +const size_t kMaxRpcVersionMinor = 2; +const size_t kMinRpcVersionMajor = 2; +const size_t kMinRpcVersionMinor = 1; + +static bool grpc_gcp_rpc_protocol_versions_equal( + grpc_gcp_rpc_protocol_versions* l_versions, + grpc_gcp_rpc_protocol_versions* r_versions) { + GPR_ASSERT(l_versions != nullptr && r_versions != nullptr); + if ((l_versions->has_max_rpc_version ^ r_versions->has_max_rpc_version) | + (l_versions->has_min_rpc_version ^ r_versions->has_min_rpc_version)) { + return false; + } + if (l_versions->has_max_rpc_version) { + if ((l_versions->max_rpc_version.major != + r_versions->max_rpc_version.major) || + (l_versions->max_rpc_version.minor != + r_versions->max_rpc_version.minor)) { + return false; + } + } + if (l_versions->has_min_rpc_version) { + if ((l_versions->min_rpc_version.major != + r_versions->min_rpc_version.major) || + (l_versions->min_rpc_version.minor != + r_versions->min_rpc_version.minor)) { + return false; + } + } + return true; +} + +static void test_success() { + grpc_gcp_rpc_protocol_versions version; + grpc_gcp_rpc_protocol_versions decoded_version; + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max( + &version, kMaxRpcVersionMajor, kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min( + &version, kMinRpcVersionMajor, kMinRpcVersionMinor)); + /* Serializes to raw bytes. */ + size_t encoded_length = + grpc_gcp_rpc_protocol_versions_encode_length(&version); + uint8_t* encoded_bytes = static_cast<uint8_t*>(gpr_malloc(encoded_length)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_encode_to_raw_bytes( + &version, encoded_bytes, encoded_length)); + grpc_slice encoded_slice; + /* Serializes to grpc slice. */ + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_encode(&version, &encoded_slice)); + /* Checks serialized raw bytes and serialized grpc slice have same content. */ + GPR_ASSERT(encoded_length == GRPC_SLICE_LENGTH(encoded_slice)); + GPR_ASSERT(memcmp(encoded_bytes, GRPC_SLICE_START_PTR(encoded_slice), + encoded_length) == 0); + /* Deserializes and compares with the original version. */ + GPR_ASSERT( + grpc_gcp_rpc_protocol_versions_decode(encoded_slice, &decoded_version)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_equal(&version, &decoded_version)); + grpc_slice_unref(encoded_slice); + gpr_free(encoded_bytes); +} + +static void test_failure() { + grpc_gcp_rpc_protocol_versions version, decoded_version; + grpc_slice encoded_slice; + /* Test for invalid arguments. */ + GPR_ASSERT(!grpc_gcp_rpc_protocol_versions_set_max( + nullptr, kMaxRpcVersionMajor, kMaxRpcVersionMinor)); + GPR_ASSERT(!grpc_gcp_rpc_protocol_versions_set_min( + nullptr, kMinRpcVersionMajor, kMinRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_encode_length(nullptr) == 0); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max( + &version, kMaxRpcVersionMajor, kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min( + &version, kMinRpcVersionMajor, kMinRpcVersionMinor)); + size_t encoded_length = + grpc_gcp_rpc_protocol_versions_encode_length(&version); + uint8_t* encoded_bytes = static_cast<uint8_t*>(gpr_malloc(encoded_length)); + GPR_ASSERT(!grpc_gcp_rpc_protocol_versions_encode_to_raw_bytes( + nullptr, encoded_bytes, encoded_length)); + GPR_ASSERT(!grpc_gcp_rpc_protocol_versions_encode_to_raw_bytes( + &version, nullptr, encoded_length)); + GPR_ASSERT(!grpc_gcp_rpc_protocol_versions_encode_to_raw_bytes( + &version, encoded_bytes, 0)); + GPR_ASSERT(!grpc_gcp_rpc_protocol_versions_encode(nullptr, &encoded_slice)); + GPR_ASSERT(!grpc_gcp_rpc_protocol_versions_encode(&version, nullptr)); + GPR_ASSERT(!grpc_gcp_rpc_protocol_versions_decode(encoded_slice, nullptr)); + /* Test for nanopb decode. */ + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_encode(&version, &encoded_slice)); + grpc_slice bad_slice = grpc_slice_split_head( + &encoded_slice, GRPC_SLICE_LENGTH(encoded_slice) - 1); + grpc_slice_unref(encoded_slice); + GPR_ASSERT( + !grpc_gcp_rpc_protocol_versions_decode(bad_slice, &decoded_version)); + grpc_slice_unref(bad_slice); + gpr_free(encoded_bytes); +} + +static void test_copy() { + grpc_gcp_rpc_protocol_versions src; + grpc_gcp_rpc_protocol_versions des; + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max(&src, kMaxRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min(&src, kMinRpcVersionMajor, + kMinRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_copy(&src, &des)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_equal(&src, &des)); +} + +static void test_check_success() { + grpc_gcp_rpc_protocol_versions v1; + grpc_gcp_rpc_protocol_versions v2; + grpc_gcp_rpc_protocol_versions_version highest_common_version; + /* test equality. */ + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max(&v1, kMaxRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min(&v1, kMaxRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max(&v2, kMaxRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min(&v2, kMaxRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_check( + (const grpc_gcp_rpc_protocol_versions*)&v1, + (const grpc_gcp_rpc_protocol_versions*)&v2, + &highest_common_version) == 1); + GPR_ASSERT(grpc_core::internal::grpc_gcp_rpc_protocol_version_compare( + &highest_common_version, &v1.max_rpc_version) == 0); + + /* test inequality. */ + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max(&v1, kMaxRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min(&v1, kMinRpcVersionMinor, + kMinRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max(&v2, kMaxRpcVersionMajor, + kMinRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min(&v2, kMinRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_check( + (const grpc_gcp_rpc_protocol_versions*)&v1, + (const grpc_gcp_rpc_protocol_versions*)&v2, + &highest_common_version) == 1); + GPR_ASSERT(grpc_core::internal::grpc_gcp_rpc_protocol_version_compare( + &highest_common_version, &v2.max_rpc_version) == 0); +} + +static void test_check_failure() { + grpc_gcp_rpc_protocol_versions v1; + grpc_gcp_rpc_protocol_versions v2; + grpc_gcp_rpc_protocol_versions_version highest_common_version; + + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max(&v1, kMinRpcVersionMajor, + kMinRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min(&v1, kMinRpcVersionMajor, + kMinRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_max(&v2, kMaxRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_set_min(&v2, kMaxRpcVersionMajor, + kMaxRpcVersionMinor)); + GPR_ASSERT(grpc_gcp_rpc_protocol_versions_check( + (const grpc_gcp_rpc_protocol_versions*)&v1, + (const grpc_gcp_rpc_protocol_versions*)&v2, + &highest_common_version) == 0); +} + +int main(int argc, char** argv) { + /* Run tests. */ + test_success(); + test_failure(); + test_copy(); + test_check_success(); + test_check_failure(); + return 0; +} |