diff options
author | Jiri Simsa <jsimsa@google.com> | 2018-04-06 17:17:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-06 17:19:59 -0700 |
commit | 5e11bbacaffdf7bc4a9363301de6a0755f95e9c0 (patch) | |
tree | 48f37585cd3b01c71eaced8724be21151374264d /tensorflow/core/util/rpc | |
parent | ddf54d1c24a2b4dcfd8eb52d21dc1f393785f1e9 (diff) |
Open sourcing proto/rpc ops.
PiperOrigin-RevId: 191962572
Diffstat (limited to 'tensorflow/core/util/rpc')
-rw-r--r-- | tensorflow/core/util/rpc/BUILD | 48 | ||||
-rw-r--r-- | tensorflow/core/util/rpc/call_container.h | 90 | ||||
-rw-r--r-- | tensorflow/core/util/rpc/rpc_factory.cc | 53 | ||||
-rw-r--r-- | tensorflow/core/util/rpc/rpc_factory.h | 70 | ||||
-rw-r--r-- | tensorflow/core/util/rpc/rpc_factory_registry.cc | 44 | ||||
-rw-r--r-- | tensorflow/core/util/rpc/rpc_factory_registry.h | 72 | ||||
-rw-r--r-- | tensorflow/core/util/rpc/rpc_factory_registry_test.cc | 41 |
7 files changed, 418 insertions, 0 deletions
diff --git a/tensorflow/core/util/rpc/BUILD b/tensorflow/core/util/rpc/BUILD new file mode 100644 index 0000000000..f0f161ecc0 --- /dev/null +++ b/tensorflow/core/util/rpc/BUILD @@ -0,0 +1,48 @@ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "call_container", + hdrs = ["call_container.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "rpc_factory", + srcs = ["rpc_factory.cc"], + hdrs = ["rpc_factory.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "rpc_factory_registry", + srcs = ["rpc_factory_registry.cc"], + hdrs = ["rpc_factory_registry.h"], + deps = [ + ":rpc_factory", + "//tensorflow/core:framework", + ], +) + +tf_cc_test( + name = "rpc_factory_registry_test", + srcs = ["rpc_factory_registry_test.cc"], + deps = [ + ":rpc_factory_registry", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) diff --git a/tensorflow/core/util/rpc/call_container.h b/tensorflow/core/util/rpc/call_container.h new file mode 100644 index 0000000000..7f36056797 --- /dev/null +++ b/tensorflow/core/util/rpc/call_container.h @@ -0,0 +1,90 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_ +#define TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_ + +#include <list> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/util/reffed_status_callback.h" + +namespace tensorflow { + +template <typename Call> +class CallContainer { + public: + explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast, + bool try_rpc, AsyncOpKernel::DoneCallback done, + CancellationToken token) + : ctx_(ctx), + done_(std::move(done)), + token_(token), + fail_fast_(fail_fast), + try_rpc_(try_rpc) { + CHECK_GT(num_calls, 0); + + // This will run when all RPCs are finished. + reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) { + ctx_->cancellation_manager()->DeregisterCallback(token_); + ctx_->SetStatus(s); + done_(); + delete this; + }); + + // Subtract reference count from the initial creation. + core::ScopedUnref unref(reffed_status_callback_); + + for (int i = 0; i < num_calls; ++i) { + // Increase the reference on the callback for each new RPC. + reffed_status_callback_->Ref(); + } + } + + std::list<Call>* calls() { return &calls_; } + + void StartCancel() { + // Once this loop is done, can no longer assume anything is valid + // because "delete this" may have been immediately called. + // Nothing should run after this loop. + for (auto& call : calls_) { + call.StartCancel(); + } + } + + void Done(const Status& s, int index) { + if (!try_rpc_) { + reffed_status_callback_->UpdateStatus(s); + } + reffed_status_callback_->Unref(); + } + + private: + OpKernelContext* ctx_; + std::list<Call> calls_; + const AsyncOpKernel::DoneCallback done_; + const CancellationToken token_; + const bool fail_fast_; + const bool try_rpc_; + + // Performs its own reference counting. + ReffedStatusCallback* reffed_status_callback_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_ diff --git a/tensorflow/core/util/rpc/rpc_factory.cc b/tensorflow/core/util/rpc/rpc_factory.cc new file mode 100644 index 0000000000..8530f02b6e --- /dev/null +++ b/tensorflow/core/util/rpc/rpc_factory.cc @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/strings/numbers.h" + +#include "tensorflow/core/util/rpc/rpc_factory.h" + +namespace tensorflow { + +template <> +bool GetEnvVar(const char* key, const string& default_value, string* value) { + const char* env_value = std::getenv(key); + if (!env_value || env_value[0] == '\0') { + *value = default_value; + } else { + *value = env_value; + } + return true; +} + +template <> +bool GetEnvVar(const char* key, const int64& default_value, int64* value) { + const char* env_value = std::getenv(key); + if (!env_value || env_value[0] == '\0') { + *value = default_value; + return true; + } + return strings::safe_strto64(env_value, value); +} + +template <> +bool GetEnvVar(const char* key, const uint64& default_value, uint64* value) { + const char* env_value = std::getenv(key); + if (!env_value || env_value[0] == '\0') { + *value = default_value; + return true; + } + return strings::safe_strtou64(env_value, value); +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/rpc/rpc_factory.h b/tensorflow/core/util/rpc/rpc_factory.h new file mode 100644 index 0000000000..9bf078c0f4 --- /dev/null +++ b/tensorflow/core/util/rpc/rpc_factory.h @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_ +#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +// Return the environment variable `key`. If the variable is not set, +// use the default value. If it is set but could not be parsed, +// return `false`. Otherwise set `value` and return `true`. +template <typename T> +bool GetEnvVar(const char* key, const T& default_value, T* value); + +class RPCFactory { + public: + RPCFactory() {} + virtual ~RPCFactory() {} + + // Start a Call() to methods `method_t` at addresses `address_t` with + // request strings from `request_t`. Any of these may be scalar + // Tensors, in which case the operands are broadcasted. + // Upon completion of all requests, `response_t` will be populated. + // + // If `try_rpc` is `true`, then `status_message_t` and + // `status_code_t` will be populated as well. + // + // If `try_rpc` is `false`, then `status_message_t` and + // `status_code_t` are ignored (and may be nullptr). Instead, the + // status of any failed call will be propagated to the op. + // + // REQUIRES: + // - `response_t` is not null, and is a string Tensor with the same shape as + // `request_t`. + // + // If `try_rpc` is `true`: + // - `status_code_t` and `status_message_t` are not null. + // - `status_code_t` is an int32 Tensor with the same shape as + // `request_t`. + // - `status_message_t` is a string Tensor with the same shape as + // `request_t`. + virtual void Call(OpKernelContext* ctx, int64 num_elements, + const Tensor& address_t, const Tensor& method_t, + const Tensor& request_t, const bool try_rpc, + Tensor* response_t, Tensor* status_code_t, + Tensor* status_message_t, + AsyncOpKernel::DoneCallback done) = 0; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RPCFactory); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_ diff --git a/tensorflow/core/util/rpc/rpc_factory_registry.cc b/tensorflow/core/util/rpc/rpc_factory_registry.cc new file mode 100644 index 0000000000..a148b5c04d --- /dev/null +++ b/tensorflow/core/util/rpc/rpc_factory_registry.cc @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <string> + +#include "tensorflow/core/util/rpc/rpc_factory.h" + +#include "tensorflow/core/util/rpc/rpc_factory_registry.h" + +namespace tensorflow { + +RPCFactoryRegistry* RPCFactoryRegistry::Global() { + static RPCFactoryRegistry* registry = new RPCFactoryRegistry; + return registry; +} + +RPCFactoryRegistry::RPCFactoryFn* RPCFactoryRegistry::Get( + const string& protocol) { + auto found = fns_.find(protocol); + if (found == fns_.end()) return nullptr; + return &found->second; +} + +void RPCFactoryRegistry::Register(const string& protocol, + const RPCFactoryFn& factory_fn) { + auto existing = Get(protocol); + CHECK_EQ(existing, nullptr) + << "RPC factory for protocol: " << protocol << " already registered"; + fns_.insert(std::pair<const string&, RPCFactoryFn>(protocol, factory_fn)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/rpc/rpc_factory_registry.h b/tensorflow/core/util/rpc/rpc_factory_registry.h new file mode 100644 index 0000000000..2635a4012e --- /dev/null +++ b/tensorflow/core/util/rpc/rpc_factory_registry.h @@ -0,0 +1,72 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_ +#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_ + +#include <map> +#include <string> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/util/rpc/rpc_factory.h" + +namespace tensorflow { + +class RPCFactoryRegistry { + public: + typedef std::function<RPCFactory*(OpKernelConstruction* ctx, bool fail_fast, + int64 timeout_in_ms)> + RPCFactoryFn; + + // Returns a pointer to a global RPCFactoryRegistry object. + static RPCFactoryRegistry* Global(); + + // Returns a pointer to an function that creates an RPC factory for the given + // protocol. + RPCFactoryFn* Get(const string& protocol); + + // Registers a function that creates and RPC factory for the given protocol. + // The function should transfer the ownership of the factory to its caller. + void Register(const string& protocol, const RPCFactoryFn& factory_fn); + + private: + std::map<string, RPCFactoryFn> fns_; +}; + +namespace rpc_factory_registration { + +class RPCFactoryRegistration { + public: + RPCFactoryRegistration(const string& protocol, + const RPCFactoryRegistry::RPCFactoryFn& factory_fn) { + RPCFactoryRegistry::Global()->Register(protocol, factory_fn); + } +}; + +} // namespace rpc_factory_registration + +#define REGISTER_RPC_FACTORY(protocol, factory_fn) \ + REGISTER_RPC_FACTORY_UNIQ_HELPER(__COUNTER__, protocol, factory_fn) + +#define REGISTER_RPC_FACTORY_UNIQ_HELPER(ctr, protocol, factory_fn) \ + REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn) + +#define REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn) \ + static rpc_factory_registration::RPCFactoryRegistration \ + rpc_factory_registration_fn_##ctr(protocol, factory_fn) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_ diff --git a/tensorflow/core/util/rpc/rpc_factory_registry_test.cc b/tensorflow/core/util/rpc/rpc_factory_registry_test.cc new file mode 100644 index 0000000000..cfd0f95016 --- /dev/null +++ b/tensorflow/core/util/rpc/rpc_factory_registry_test.cc @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/util/rpc/rpc_factory_registry.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +struct Value { + static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast, + int64 timeout_in_ms) { + return nullptr; + } +}; + +REGISTER_RPC_FACTORY("TEST FACTORY 1", Value::Function); +REGISTER_RPC_FACTORY("TEST FACTORY 2", Value::Function); +} // namespace + +TEST(RPCFactoryRegistryTest, TestBasic) { + EXPECT_EQ(RPCFactoryRegistry::Global()->Get("NON-EXISTENT"), nullptr); + auto factory1 = RPCFactoryRegistry::Global()->Get("TEST FACTORY 1"); + EXPECT_NE(factory1, nullptr); + auto factory2 = RPCFactoryRegistry::Global()->Get("TEST FACTORY 2"); + EXPECT_NE(factory2, nullptr); +} + +} // namespace tensorflow |