aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/rpc
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-04-06 17:17:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-06 17:19:59 -0700
commit5e11bbacaffdf7bc4a9363301de6a0755f95e9c0 (patch)
tree48f37585cd3b01c71eaced8724be21151374264d /tensorflow/core/util/rpc
parentddf54d1c24a2b4dcfd8eb52d21dc1f393785f1e9 (diff)
Open sourcing proto/rpc ops.
PiperOrigin-RevId: 191962572
Diffstat (limited to 'tensorflow/core/util/rpc')
-rw-r--r--tensorflow/core/util/rpc/BUILD48
-rw-r--r--tensorflow/core/util/rpc/call_container.h90
-rw-r--r--tensorflow/core/util/rpc/rpc_factory.cc53
-rw-r--r--tensorflow/core/util/rpc/rpc_factory.h70
-rw-r--r--tensorflow/core/util/rpc/rpc_factory_registry.cc44
-rw-r--r--tensorflow/core/util/rpc/rpc_factory_registry.h72
-rw-r--r--tensorflow/core/util/rpc/rpc_factory_registry_test.cc41
7 files changed, 418 insertions, 0 deletions
diff --git a/tensorflow/core/util/rpc/BUILD b/tensorflow/core/util/rpc/BUILD
new file mode 100644
index 0000000000..f0f161ecc0
--- /dev/null
+++ b/tensorflow/core/util/rpc/BUILD
@@ -0,0 +1,48 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+cc_library(
+ name = "call_container",
+ hdrs = ["call_container.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "rpc_factory",
+ srcs = ["rpc_factory.cc"],
+ hdrs = ["rpc_factory.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "rpc_factory_registry",
+ srcs = ["rpc_factory_registry.cc"],
+ hdrs = ["rpc_factory_registry.h"],
+ deps = [
+ ":rpc_factory",
+ "//tensorflow/core:framework",
+ ],
+)
+
+tf_cc_test(
+ name = "rpc_factory_registry_test",
+ srcs = ["rpc_factory_registry_test.cc"],
+ deps = [
+ ":rpc_factory_registry",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
diff --git a/tensorflow/core/util/rpc/call_container.h b/tensorflow/core/util/rpc/call_container.h
new file mode 100644
index 0000000000..7f36056797
--- /dev/null
+++ b/tensorflow/core/util/rpc/call_container.h
@@ -0,0 +1,90 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
+#define TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
+
+#include <list>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/util/reffed_status_callback.h"
+
+namespace tensorflow {
+
+template <typename Call>
+class CallContainer {
+ public:
+ explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast,
+ bool try_rpc, AsyncOpKernel::DoneCallback done,
+ CancellationToken token)
+ : ctx_(ctx),
+ done_(std::move(done)),
+ token_(token),
+ fail_fast_(fail_fast),
+ try_rpc_(try_rpc) {
+ CHECK_GT(num_calls, 0);
+
+ // This will run when all RPCs are finished.
+ reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
+ ctx_->cancellation_manager()->DeregisterCallback(token_);
+ ctx_->SetStatus(s);
+ done_();
+ delete this;
+ });
+
+ // Subtract reference count from the initial creation.
+ core::ScopedUnref unref(reffed_status_callback_);
+
+ for (int i = 0; i < num_calls; ++i) {
+ // Increase the reference on the callback for each new RPC.
+ reffed_status_callback_->Ref();
+ }
+ }
+
+ std::list<Call>* calls() { return &calls_; }
+
+ void StartCancel() {
+ // Once this loop is done, can no longer assume anything is valid
+ // because "delete this" may have been immediately called.
+ // Nothing should run after this loop.
+ for (auto& call : calls_) {
+ call.StartCancel();
+ }
+ }
+
+ void Done(const Status& s, int index) {
+ if (!try_rpc_) {
+ reffed_status_callback_->UpdateStatus(s);
+ }
+ reffed_status_callback_->Unref();
+ }
+
+ private:
+ OpKernelContext* ctx_;
+ std::list<Call> calls_;
+ const AsyncOpKernel::DoneCallback done_;
+ const CancellationToken token_;
+ const bool fail_fast_;
+ const bool try_rpc_;
+
+ // Performs its own reference counting.
+ ReffedStatusCallback* reffed_status_callback_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
diff --git a/tensorflow/core/util/rpc/rpc_factory.cc b/tensorflow/core/util/rpc/rpc_factory.cc
new file mode 100644
index 0000000000..8530f02b6e
--- /dev/null
+++ b/tensorflow/core/util/rpc/rpc_factory.cc
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/strings/numbers.h"
+
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+
+namespace tensorflow {
+
+template <>
+bool GetEnvVar(const char* key, const string& default_value, string* value) {
+ const char* env_value = std::getenv(key);
+ if (!env_value || env_value[0] == '\0') {
+ *value = default_value;
+ } else {
+ *value = env_value;
+ }
+ return true;
+}
+
+template <>
+bool GetEnvVar(const char* key, const int64& default_value, int64* value) {
+ const char* env_value = std::getenv(key);
+ if (!env_value || env_value[0] == '\0') {
+ *value = default_value;
+ return true;
+ }
+ return strings::safe_strto64(env_value, value);
+}
+
+template <>
+bool GetEnvVar(const char* key, const uint64& default_value, uint64* value) {
+ const char* env_value = std::getenv(key);
+ if (!env_value || env_value[0] == '\0') {
+ *value = default_value;
+ return true;
+ }
+ return strings::safe_strtou64(env_value, value);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/rpc/rpc_factory.h b/tensorflow/core/util/rpc/rpc_factory.h
new file mode 100644
index 0000000000..9bf078c0f4
--- /dev/null
+++ b/tensorflow/core/util/rpc/rpc_factory.h
@@ -0,0 +1,70 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
+#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+// Return the environment variable `key`. If the variable is not set,
+// use the default value. If it is set but could not be parsed,
+// return `false`. Otherwise set `value` and return `true`.
+template <typename T>
+bool GetEnvVar(const char* key, const T& default_value, T* value);
+
+class RPCFactory {
+ public:
+ RPCFactory() {}
+ virtual ~RPCFactory() {}
+
+ // Start a Call() to methods `method_t` at addresses `address_t` with
+ // request strings from `request_t`. Any of these may be scalar
+ // Tensors, in which case the operands are broadcasted.
+ // Upon completion of all requests, `response_t` will be populated.
+ //
+ // If `try_rpc` is `true`, then `status_message_t` and
+ // `status_code_t` will be populated as well.
+ //
+ // If `try_rpc` is `false`, then `status_message_t` and
+ // `status_code_t` are ignored (and may be nullptr). Instead, the
+ // status of any failed call will be propagated to the op.
+ //
+ // REQUIRES:
+ // - `response_t` is not null, and is a string Tensor with the same shape as
+ // `request_t`.
+ //
+ // If `try_rpc` is `true`:
+ // - `status_code_t` and `status_message_t` are not null.
+ // - `status_code_t` is an int32 Tensor with the same shape as
+ // `request_t`.
+ // - `status_message_t` is a string Tensor with the same shape as
+ // `request_t`.
+ virtual void Call(OpKernelContext* ctx, int64 num_elements,
+ const Tensor& address_t, const Tensor& method_t,
+ const Tensor& request_t, const bool try_rpc,
+ Tensor* response_t, Tensor* status_code_t,
+ Tensor* status_message_t,
+ AsyncOpKernel::DoneCallback done) = 0;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(RPCFactory);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
diff --git a/tensorflow/core/util/rpc/rpc_factory_registry.cc b/tensorflow/core/util/rpc/rpc_factory_registry.cc
new file mode 100644
index 0000000000..a148b5c04d
--- /dev/null
+++ b/tensorflow/core/util/rpc/rpc_factory_registry.cc
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+
+#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
+
+namespace tensorflow {
+
+RPCFactoryRegistry* RPCFactoryRegistry::Global() {
+ static RPCFactoryRegistry* registry = new RPCFactoryRegistry;
+ return registry;
+}
+
+RPCFactoryRegistry::RPCFactoryFn* RPCFactoryRegistry::Get(
+ const string& protocol) {
+ auto found = fns_.find(protocol);
+ if (found == fns_.end()) return nullptr;
+ return &found->second;
+}
+
+void RPCFactoryRegistry::Register(const string& protocol,
+ const RPCFactoryFn& factory_fn) {
+ auto existing = Get(protocol);
+ CHECK_EQ(existing, nullptr)
+ << "RPC factory for protocol: " << protocol << " already registered";
+ fns_.insert(std::pair<const string&, RPCFactoryFn>(protocol, factory_fn));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/rpc/rpc_factory_registry.h b/tensorflow/core/util/rpc/rpc_factory_registry.h
new file mode 100644
index 0000000000..2635a4012e
--- /dev/null
+++ b/tensorflow/core/util/rpc/rpc_factory_registry.h
@@ -0,0 +1,72 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
+#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
+
+#include <map>
+#include <string>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+
+namespace tensorflow {
+
+class RPCFactoryRegistry {
+ public:
+ typedef std::function<RPCFactory*(OpKernelConstruction* ctx, bool fail_fast,
+ int64 timeout_in_ms)>
+ RPCFactoryFn;
+
+ // Returns a pointer to a global RPCFactoryRegistry object.
+ static RPCFactoryRegistry* Global();
+
+ // Returns a pointer to an function that creates an RPC factory for the given
+ // protocol.
+ RPCFactoryFn* Get(const string& protocol);
+
+ // Registers a function that creates and RPC factory for the given protocol.
+ // The function should transfer the ownership of the factory to its caller.
+ void Register(const string& protocol, const RPCFactoryFn& factory_fn);
+
+ private:
+ std::map<string, RPCFactoryFn> fns_;
+};
+
+namespace rpc_factory_registration {
+
+class RPCFactoryRegistration {
+ public:
+ RPCFactoryRegistration(const string& protocol,
+ const RPCFactoryRegistry::RPCFactoryFn& factory_fn) {
+ RPCFactoryRegistry::Global()->Register(protocol, factory_fn);
+ }
+};
+
+} // namespace rpc_factory_registration
+
+#define REGISTER_RPC_FACTORY(protocol, factory_fn) \
+ REGISTER_RPC_FACTORY_UNIQ_HELPER(__COUNTER__, protocol, factory_fn)
+
+#define REGISTER_RPC_FACTORY_UNIQ_HELPER(ctr, protocol, factory_fn) \
+ REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn)
+
+#define REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn) \
+ static rpc_factory_registration::RPCFactoryRegistration \
+ rpc_factory_registration_fn_##ctr(protocol, factory_fn)
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
diff --git a/tensorflow/core/util/rpc/rpc_factory_registry_test.cc b/tensorflow/core/util/rpc/rpc_factory_registry_test.cc
new file mode 100644
index 0000000000..cfd0f95016
--- /dev/null
+++ b/tensorflow/core/util/rpc/rpc_factory_registry_test.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+struct Value {
+ static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast,
+ int64 timeout_in_ms) {
+ return nullptr;
+ }
+};
+
+REGISTER_RPC_FACTORY("TEST FACTORY 1", Value::Function);
+REGISTER_RPC_FACTORY("TEST FACTORY 2", Value::Function);
+} // namespace
+
+TEST(RPCFactoryRegistryTest, TestBasic) {
+ EXPECT_EQ(RPCFactoryRegistry::Global()->Get("NON-EXISTENT"), nullptr);
+ auto factory1 = RPCFactoryRegistry::Global()->Get("TEST FACTORY 1");
+ EXPECT_NE(factory1, nullptr);
+ auto factory2 = RPCFactoryRegistry::Global()->Get("TEST FACTORY 2");
+ EXPECT_NE(factory2, nullptr);
+}
+
+} // namespace tensorflow