author    Shanqing Cai <cais@google.com>    2017-07-19 07:23:53 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-07-19 07:28:28 -0700
commit    41803db36d4f4a3239bd81e5d460eb0e6e2eea88 (patch)
tree      4c677032f4a6b607ec1bfa968f2a36efc01ba04d
parent    ac7530e54ddaa17c17e070e5a141002c43b86275 (diff)
tfdbg: open-source C++ and Python libraries of gRPC debugger mode, with the exception of Windows.

The session_debug_grpc_test is temporarily disabled on Mac, pending pip install of futures and grpcio on all slaves.

PiperOrigin-RevId: 162482541
-rw-r--r--  tensorflow/contrib/cmake/tf_tests.cmake | 3
-rw-r--r--  tensorflow/core/debug/BUILD | 42
-rw-r--r--  tensorflow/core/debug/debug_grpc_io_utils_test.cc | 432
-rw-r--r--  tensorflow/core/debug/debug_io_utils.cc | 45
-rw-r--r--  tensorflow/core/debug/debug_io_utils.h | 4
-rw-r--r--  tensorflow/core/platform/default/build_config.bzl | 4
-rw-r--r--  tensorflow/python/debug/BUILD | 91
-rw-r--r--  tensorflow/python/debug/lib/debug_data.py | 3
-rwxr-xr-x  tensorflow/python/debug/lib/debug_service_pb2_grpc.py | 76
-rw-r--r--  tensorflow/python/debug/lib/dist_session_debug_grpc_test.py | 230
-rw-r--r--  tensorflow/python/debug/lib/grpc_debug_server.py | 395
-rw-r--r--  tensorflow/python/debug/lib/grpc_debug_test_server.py | 336
-rw-r--r--  tensorflow/python/debug/lib/session_debug_grpc_test.py | 622
-rwxr-xr-x  tensorflow/tools/ci_build/builds/pip.sh | 118
-rw-r--r--  tensorflow/tools/pip_package/BUILD | 1
15 files changed, 2330 insertions, 72 deletions
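
For illustration, a minimal client-side sketch of the new grpc:// debug mode, based on the session wrapper exercised in dist_session_debug_grpc_test.py below (the port number and the watched node name are illustrative; a debug server must already be listening):

    from tensorflow.python.client import session
    from tensorflow.python.debug.wrappers import framework
    from tensorflow.python.debug.wrappers import grpc_wrapper

    def watch_fn(feeds, fetch_keys):
      del feeds, fetch_keys  # This sketch watches the same node on every run.
      return framework.WatchOptions(
          debug_ops=["DebugIdentity"], node_name_regex_whitelist=r"p")

    sess = session.Session()
    # Stream intermediate tensors to grpc://localhost:6064 on each run.
    sess = grpc_wrapper.GrpcDebugWrapperSession(
        sess, "localhost:6064", watch_fn=watch_fn)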
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index 089967ffec..0b73f8bb7f 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -162,6 +162,9 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/debug/cli/profile_analyzer_cli_test.py"
# Windows does not have the curses library and uses readline.
"${tensorflow_source_dir}/tensorflow/python/debug/cli/curses_ui_test.py"
+ # TFDBG grpc:// mode is not yet available on Windows.
+ "${tensorflow_source_dir}/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py"
+ "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py"
# generally not working
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/__init__.py"
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py"
diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD
index dd26177d25..971576698e 100644
--- a/tensorflow/core/debug/BUILD
+++ b/tensorflow/core/debug/BUILD
@@ -15,7 +15,6 @@ package(
licenses(["notice"]) # Apache 2.0
-# Google-internal rules omitted.
load(
"//tensorflow:tensorflow.bzl",
"check_deps",
@@ -28,6 +27,7 @@ load(
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_kernel_tests_linkstatic",
+ "tf_proto_library",
"tf_proto_library_cc",
)
load(
@@ -42,13 +42,25 @@ check_deps(
deps = ["//tensorflow/core:tensorflow"],
)
-tf_proto_library_cc(
+tf_proto_library(
name = "debug_service_proto",
- srcs = ["debug_service.proto"],
+ srcs = [
+ "debug_service.proto",
+ ],
has_services = 1,
cc_api_version = 2,
cc_grpc_version = 1,
- protodeps = ["//tensorflow/core:protos_all"],
+ protodeps = [
+ ":debugger_event_metadata_proto",
+ "//tensorflow/core:protos_all",
+ ],
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+tf_proto_library(
+ name = "debugger_event_metadata_proto",
+ srcs = ["debugger_event_metadata.proto"],
+ cc_api_version = 2,
)
cc_library(
@@ -196,6 +208,7 @@ tf_cc_test(
deps = [
":debug_grpc_testlib",
":debug_io_utils",
+ ":debug_service_proto_cc",
":debugger_event_metadata_proto_cc",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
@@ -254,10 +267,23 @@ tf_cc_test(
],
)
-tf_proto_library_cc(
- name = "debugger_event_metadata_proto",
- srcs = ["debugger_event_metadata.proto"],
- cc_api_version = 2,
+tf_cc_test(
+ name = "debug_grpc_io_utils_test",
+ size = "small",
+ srcs = ["debug_grpc_io_utils_test.cc"],
+ linkstatic = tf_kernel_tests_linkstatic(),
+ deps = [
+ ":debug_graph_utils",
+ ":debug_grpc_testlib",
+ ":debug_io_utils",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
)
# TODO(cais): Add the following back in when tfdbg is supported on Android.
diff --git a/tensorflow/core/debug/debug_grpc_io_utils_test.cc b/tensorflow/core/debug/debug_grpc_io_utils_test.cc
new file mode 100644
index 0000000000..6510424182
--- /dev/null
+++ b/tensorflow/core/debug/debug_grpc_io_utils_test.cc
@@ -0,0 +1,432 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/debug/debug_graph_utils.h"
+#include "tensorflow/core/debug/debug_grpc_testlib.h"
+#include "tensorflow/core/debug/debug_io_utils.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/tracing.h"
+
+namespace tensorflow {
+
+class GrpcDebugTest : public ::testing::Test {
+ protected:
+ struct ServerData {
+ int port;
+ string url;
+ std::unique_ptr<test::TestEventListenerImpl> server;
+ std::unique_ptr<thread::ThreadPool> thread_pool;
+ };
+
+ void SetUp() override {
+ ClearEnabledWatchKeys();
+ SetUpInProcessServer(&server_data_, 0);
+ }
+
+ void TearDown() override { TearDownInProcessServer(&server_data_); }
+
+ void SetUpInProcessServer(ServerData* server_data,
+ int64 server_start_delay_micros) {
+ server_data->port = testing::PickUnusedPortOrDie();
+ server_data->url = strings::StrCat("grpc://localhost:", server_data->port);
+ server_data->server.reset(new test::TestEventListenerImpl());
+
+ server_data->thread_pool.reset(
+ new thread::ThreadPool(Env::Default(), "test_server", 1));
+ server_data->thread_pool->Schedule(
+ [server_data, server_start_delay_micros]() {
+ Env::Default()->SleepForMicroseconds(server_start_delay_micros);
+ server_data->server->RunServer(server_data->port);
+ });
+ }
+
+ void TearDownInProcessServer(ServerData* server_data) {
+ server_data->server->StopServer();
+ server_data->thread_pool.reset();
+ }
+
+ void ClearEnabledWatchKeys() { DebugGrpcIO::ClearEnabledWatchKeys(); }
+
+ void CreateEmptyEnabledSet(const string& grpc_debug_url) {
+ DebugGrpcIO::CreateEmptyEnabledSet(grpc_debug_url);
+ }
+
+ const int64 GetChannelConnectionTimeoutMicros() {
+ return DebugGrpcIO::channel_connection_timeout_micros;
+ }
+
+ void SetChannelConnectionTimeoutMicros(const int64 timeout) {
+ DebugGrpcIO::channel_connection_timeout_micros = timeout;
+ }
+
+ ServerData server_data_;
+};
+
+TEST_F(GrpcDebugTest, ConnectionTimeoutWorks) {
+ // Use a short timeout so the test won't take too long.
+ const int64 kOriginalTimeoutMicros = GetChannelConnectionTimeoutMicros();
+ const int64 kShortTimeoutMicros = 500 * 1000;
+ SetChannelConnectionTimeoutMicros(kShortTimeoutMicros);
+ ASSERT_EQ(kShortTimeoutMicros, GetChannelConnectionTimeoutMicros());
+
+ const string& kInvalidGrpcUrl =
+ strings::StrCat("grpc://localhost:", testing::PickUnusedPortOrDie());
+ Tensor tensor(DT_FLOAT, TensorShape({1, 1}));
+ tensor.flat<float>()(0) = 42.0;
+ Status publish_status = DebugIO::PublishDebugTensor(
+ DebugNodeKey("/job:localhost/replica:0/task:0/cpu:0", "foo_tensor", 0,
+ "DebugIdentity"),
+ tensor, Env::Default()->NowMicros(), {kInvalidGrpcUrl});
+ SetChannelConnectionTimeoutMicros(kOriginalTimeoutMicros);
+ TF_ASSERT_OK(DebugIO::CloseDebugURL(kInvalidGrpcUrl));
+
+ ASSERT_FALSE(publish_status.ok());
+ const string expected_error_msg = strings::StrCat(
+ "Failed to connect to gRPC channel at ", kInvalidGrpcUrl.substr(7),
+ " within a timeout of ", kShortTimeoutMicros / 1e6, " s");
+ ASSERT_NE(string::npos,
+ publish_status.error_message().find(expected_error_msg));
+}
+
+TEST_F(GrpcDebugTest, ConnectionToDelayedStartingServerWorks) {
+ ServerData server_data;
+ // Server start will be delayed for 1 second.
+ SetUpInProcessServer(&server_data, 1 * 1000 * 1000);
+
+ Tensor tensor(DT_FLOAT, TensorShape({1, 1}));
+ tensor.flat<float>()(0) = 42.0;
+ const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
+ "foo_tensor", 0, "DebugIdentity");
+ Status publish_status = DebugIO::PublishDebugTensor(
+ kDebugNodeKey, tensor, Env::Default()->NowMicros(), {server_data.url});
+ ASSERT_TRUE(publish_status.ok());
+ TF_ASSERT_OK(DebugIO::CloseDebugURL(server_data.url));
+
+ ASSERT_EQ(1, server_data.server->node_names.size());
+ ASSERT_EQ(1, server_data.server->output_slots.size());
+ ASSERT_EQ(1, server_data.server->debug_ops.size());
+ EXPECT_EQ(kDebugNodeKey.device_name, server_data.server->device_names[0]);
+ EXPECT_EQ(kDebugNodeKey.node_name, server_data.server->node_names[0]);
+ EXPECT_EQ(kDebugNodeKey.output_slot, server_data.server->output_slots[0]);
+ EXPECT_EQ(kDebugNodeKey.debug_op, server_data.server->debug_ops[0]);
+ TearDownInProcessServer(&server_data);
+}
+
+TEST_F(GrpcDebugTest, SendSingleDebugTensorViaGrpcTest) {
+ Tensor tensor(DT_FLOAT, TensorShape({1, 1}));
+ tensor.flat<float>()(0) = 42.0;
+ const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
+ "foo_tensor", 0, "DebugIdentity");
+ TF_ASSERT_OK(DebugIO::PublishDebugTensor(
+ kDebugNodeKey, tensor, Env::Default()->NowMicros(), {server_data_.url}));
+ TF_ASSERT_OK(DebugIO::CloseDebugURL(server_data_.url));
+
+ // Verify that the expected debug tensor sending happened.
+ ASSERT_EQ(1, server_data_.server->node_names.size());
+ ASSERT_EQ(1, server_data_.server->output_slots.size());
+ ASSERT_EQ(1, server_data_.server->debug_ops.size());
+ EXPECT_EQ(kDebugNodeKey.device_name, server_data_.server->device_names[0]);
+ EXPECT_EQ(kDebugNodeKey.node_name, server_data_.server->node_names[0]);
+ EXPECT_EQ(kDebugNodeKey.output_slot, server_data_.server->output_slots[0]);
+ EXPECT_EQ(kDebugNodeKey.debug_op, server_data_.server->debug_ops[0]);
+}
+
+TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex0ViaGrpcTest) {
+ Tensor tensor(DT_STRING, TensorShape({1, 1}));
+ tensor.flat<string>()(0) = string(5000 * 1024, 'A');
+ const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
+ "foo_tensor", 0, "DebugIdentity");
+ const Status status = DebugIO::PublishDebugTensor(
+ kDebugNodeKey, tensor, Env::Default()->NowMicros(), {server_data_.url});
+ ASSERT_FALSE(status.ok());
+ ASSERT_NE(status.error_message().find("string value at index 0 from debug "
+ "node foo_tensor:0:DebugIdentity does "
+ "not fit gRPC message size limit"),
+ string::npos);
+ TF_ASSERT_OK(DebugIO::CloseDebugURL(server_data_.url));
+}
+
+TEST_F(GrpcDebugTest, SendDebugTensorWithLargeStringAtIndex1ViaGrpcTest) {
+ Tensor tensor(DT_STRING, TensorShape({1, 2}));
+ tensor.flat<string>()(0) = "A";
+ tensor.flat<string>()(1) = string(5000 * 1024, 'A');
+ const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
+ "foo_tensor", 0, "DebugIdentity");
+ const Status status = DebugIO::PublishDebugTensor(
+ kDebugNodeKey, tensor, Env::Default()->NowMicros(), {server_data_.url});
+ ASSERT_FALSE(status.ok());
+ ASSERT_NE(status.error_message().find("string value at index 1 from debug "
+ "node foo_tensor:0:DebugIdentity does "
+ "not fit gRPC message size limit"),
+ string::npos);
+ TF_ASSERT_OK(DebugIO::CloseDebugURL(server_data_.url));
+}
+
+TEST_F(GrpcDebugTest, SendMultipleDebugTensorsSynchronizedViaGrpcTest) {
+ const int32 kSends = 4;
+
+ // Prepare the tensors to be sent.
+ std::vector<Tensor> tensors;
+ for (int i = 0; i < kSends; ++i) {
+ Tensor tensor(DT_INT32, TensorShape({1, 1}));
+ tensor.flat<int>()(0) = i * i;
+ tensors.push_back(tensor);
+ }
+
+ thread::ThreadPool* tp =
+ new thread::ThreadPool(Env::Default(), "grpc_debug_test", kSends);
+
+ mutex mu;
+ Notification all_done;
+ int tensor_count GUARDED_BY(mu) = 0;
+ std::vector<Status> statuses GUARDED_BY(mu);
+
+ const std::vector<string> urls({server_data_.url});
+
+ // Set up the concurrent tasks of sending Tensors via an Event stream to the
+ // server.
+ auto fn = [this, &mu, &tensor_count, &tensors, &statuses, &all_done,
+ &urls]() {
+ int this_count;
+ {
+ mutex_lock l(mu);
+ this_count = tensor_count++;
+ }
+
+ // Different concurrent tasks will send different tensors.
+ const uint64 wall_time = Env::Default()->NowMicros();
+ Status publish_status = DebugIO::PublishDebugTensor(
+ DebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
+ strings::StrCat("synchronized_node_", this_count), 0,
+ "DebugIdentity"),
+ tensors[this_count], wall_time, urls);
+
+ {
+ mutex_lock l(mu);
+ statuses.push_back(publish_status);
+ if (this_count == kSends - 1 && !all_done.HasBeenNotified()) {
+ all_done.Notify();
+ }
+ }
+ };
+
+ // Schedule the concurrent tasks.
+ for (int i = 0; i < kSends; ++i) {
+ tp->Schedule(fn);
+ }
+
+ // Wait for all client tasks to finish.
+ all_done.WaitForNotification();
+ delete tp;
+
+ // Close the debug gRPC stream.
+ Status close_status = DebugIO::CloseDebugURL(server_data_.url);
+ ASSERT_TRUE(close_status.ok());
+
+ // Check all statuses from the PublishDebugTensor() calls.
+ for (const Status& status : statuses) {
+ TF_ASSERT_OK(status);
+ }
+
+ // kSends concurrent tensors are expected.
+ ASSERT_EQ(kSends, server_data_.server->node_names.size());
+ for (size_t i = 0; i < server_data_.server->node_names.size(); ++i) {
+ std::vector<string> items =
+ str_util::Split(server_data_.server->node_names[i], '_');
+ int tensor_index;
+ strings::safe_strto32(items[2], &tensor_index);
+
+ ASSERT_EQ(TensorShape({1, 1}),
+ server_data_.server->debug_tensors[i].shape());
+ ASSERT_EQ(tensor_index * tensor_index,
+ server_data_.server->debug_tensors[i].flat<int>()(0));
+ }
+}
+
+TEST_F(GrpcDebugTest, SendDebugTensorsThroughMultipleRoundsUsingGrpcGating) {
+ // Prepare the tensor to send.
+ const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
+ "test_namescope/test_node", 0,
+ "DebugIdentity");
+ Tensor tensor(DT_INT32, TensorShape({1, 1}));
+ tensor.flat<int>()(0) = 42;
+
+ const std::vector<string> urls({server_data_.url});
+ for (int i = 0; i < 3; ++i) {
+ server_data_.server->ClearReceivedDebugData();
+ const uint64 wall_time = Env::Default()->NowMicros();
+
+ // On the 1st send (i == 0), gating is disabled, so data should be sent.
+ // On the 2nd send (i == 1), gating is enabled, and the server has enabled
+ // the watch key in the previous send, so data should be sent.
+ // On the 3rd send (i == 2), gating is enabled, but the server has disabled
+ // the watch key in the previous send, so data should not be sent.
+ const bool enable_gated_grpc = (i != 0);
+ TF_ASSERT_OK(DebugIO::PublishDebugTensor(kDebugNodeKey, tensor, wall_time,
+ urls, enable_gated_grpc));
+
+ server_data_.server->RequestDebugOpStateChangeAtNextStream(i == 0,
+ kDebugNodeKey);
+
+ // Close the debug gRPC stream.
+ Status close_status = DebugIO::CloseDebugURL(server_data_.url);
+ ASSERT_TRUE(close_status.ok());
+
+ // Check dumped files according to the expected gating results.
+ if (i < 2) {
+ ASSERT_EQ(1, server_data_.server->node_names.size());
+ ASSERT_EQ(1, server_data_.server->output_slots.size());
+ ASSERT_EQ(1, server_data_.server->debug_ops.size());
+ EXPECT_EQ(kDebugNodeKey.device_name,
+ server_data_.server->device_names[0]);
+ EXPECT_EQ(kDebugNodeKey.node_name, server_data_.server->node_names[0]);
+ EXPECT_EQ(kDebugNodeKey.output_slot,
+ server_data_.server->output_slots[0]);
+ EXPECT_EQ(kDebugNodeKey.debug_op, server_data_.server->debug_ops[0]);
+ } else {
+ ASSERT_EQ(0, server_data_.server->node_names.size());
+ }
+ }
+}
+
+TEST_F(GrpcDebugTest, TestGateDebugNodeOnEmptyEnabledSet) {
+ CreateEmptyEnabledSet("grpc://localhost:3333");
+
+ ASSERT_FALSE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity",
+ {"grpc://localhost:3333"}));
+
+ // file:// debug URLs are not subject to grpc gating.
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen(
+ "foo:0:DebugIdentity", {"grpc://localhost:3333", "file:///tmp/tfdbg_1"}));
+}
+
+TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSet) {
+ const string kGrpcUrl1 = "grpc://localhost:3333";
+ const string kGrpcUrl2 = "grpc://localhost:3334";
+
+ DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "foo:0:DebugIdentity");
+ DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "bar:0:DebugIdentity");
+
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("foo:1:DebugIdentity", {kGrpcUrl1}));
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("foo:1:DebugNumericSummary", {kGrpcUrl1}));
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("qux:0:DebugIdentity", {kGrpcUrl1}));
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl1}));
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl1}));
+
+ // Wrong grpc:// debug URLs.
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl2}));
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl2}));
+
+ // file:// debug URLs are not subject to grpc gating.
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("qux:0:DebugIdentity",
+ {"file:///tmp/tfdbg_1", kGrpcUrl1}));
+}
+
+TEST_F(GrpcDebugTest, TestGateDebugNodeOnMultipleEmptyEnabledSets) {
+ const string kGrpcUrl1 = "grpc://localhost:3333";
+ const string kGrpcUrl2 = "grpc://localhost:3334";
+ const string kGrpcUrl3 = "grpc://localhost:3335";
+
+ DebugGrpcIO::EnableWatchKey(kGrpcUrl1, "foo:0:DebugIdentity");
+ DebugGrpcIO::EnableWatchKey(kGrpcUrl2, "bar:0:DebugIdentity");
+ CreateEmptyEnabledSet(kGrpcUrl3);
+
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl1}));
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl2}));
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl2}));
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl1}));
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity", {kGrpcUrl3}));
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity", {kGrpcUrl3}));
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity",
+ {kGrpcUrl1, kGrpcUrl2}));
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity",
+ {kGrpcUrl1, kGrpcUrl2}));
+ ASSERT_TRUE(DebugIO::IsDebugNodeGateOpen("foo:0:DebugIdentity",
+ {kGrpcUrl1, kGrpcUrl3}));
+ ASSERT_FALSE(DebugIO::IsDebugNodeGateOpen("bar:0:DebugIdentity",
+ {kGrpcUrl1, kGrpcUrl3}));
+}
+
+TEST_F(GrpcDebugTest, TestGateDebugNodeOnNonEmptyEnabledSetAndEmptyURLs) {
+ DebugGrpcIO::EnableWatchKey("grpc://localhost:3333", "foo:0:DebugIdentity");
+
+ std::vector<string> debug_urls_1;
+ ASSERT_FALSE(
+ DebugIO::IsDebugNodeGateOpen("foo:1:DebugIdentity", debug_urls_1));
+}
+
+TEST_F(GrpcDebugTest, TestGateCopyNodeOnEmptyEnabledSet) {
+ const string kGrpcUrl1 = "grpc://localhost:3333";
+ const string kWatch1 = "foo:0:DebugIdentity";
+ CreateEmptyEnabledSet(kGrpcUrl1);
+
+ ASSERT_FALSE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)}));
+ ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, false)}));
+
+ // file:// debug URLs are not subject to grpc gating.
+ ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec("foo:0:DebugIdentity", kGrpcUrl1, true),
+ DebugWatchAndURLSpec("foo:0:DebugIdentity", "file:///tmp/tfdbg_1",
+ false)}));
+}
+
+TEST_F(GrpcDebugTest, TestGateCopyNodeOnNonEmptyEnabledSet) {
+ const string kGrpcUrl1 = "grpc://localhost:3333";
+ const string kGrpcUrl2 = "grpc://localhost:3334";
+ const string kWatch1 = "foo:0:DebugIdentity";
+ const string kWatch2 = "foo:1:DebugIdentity";
+ CreateEmptyEnabledSet(kGrpcUrl1);
+ CreateEmptyEnabledSet(kGrpcUrl2);
+ DebugGrpcIO::EnableWatchKey(kGrpcUrl1, kWatch1);
+
+ ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true)}));
+
+ ASSERT_FALSE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch1, kGrpcUrl2, true)}));
+ ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch1, kGrpcUrl2, false)}));
+
+ ASSERT_FALSE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch2, kGrpcUrl1, true)}));
+ ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch2, kGrpcUrl1, false)}));
+
+ ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true),
+ DebugWatchAndURLSpec(kWatch1, kGrpcUrl2, true)}));
+ ASSERT_TRUE(DebugIO::IsCopyNodeGateOpen(
+ {DebugWatchAndURLSpec(kWatch1, kGrpcUrl1, true),
+ DebugWatchAndURLSpec(kWatch2, kGrpcUrl2, true)}));
+}
+
+} // namespace tensorflow
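
The Python test server added later in this patch (grpc_debug_test_server.py) plays the same role as TestEventListenerImpl above. A sketch of how the Python tests drive it, with signatures taken from the code below (dump_to_filesystem=False keeps received data in memory):

    from tensorflow.python.debug.lib import grpc_debug_test_server

    # Start an EventListener server on an unused port, on its own thread.
    (port, url, _, server_thread,
     server) = grpc_debug_test_server.start_server_on_separate_thread(
         dump_to_filesystem=False)
    # ... run a debugged Session with debug_urls=[url] ...
    server.stop_server().wait()  # stop_server() returns a threading.Event.
    server_thread.join()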
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 4e185b37a8..f4208a0bbc 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -17,14 +17,12 @@ limitations under the License.
#include <vector>
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
#include "grpc++/create_channel.h"
-#endif
-
-#if defined(PLATFORM_WINDOWS)
+#else
// winsock2.h is used in grpc, so Ws2_32.lib is needed
#pragma comment(lib,"Ws2_32.lib")
-#endif
+#endif // #ifndef PLATFORM_WINDOWS
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -37,10 +35,9 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/event.pb.h"
-#define GRPC_OSS_UNIMPLEMENTED_ERROR \
- return errors::Unimplemented( \
- kGrpcURLScheme, \
- " debug URL scheme is not implemented in open source yet.")
+#define GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR \
+ return errors::Unimplemented( \
+ kGrpcURLScheme, " debug URL scheme is not implemented on Windows yet.")
namespace tensorflow {
@@ -234,7 +231,7 @@ string AppendTimestampToFilePath(const string& in, const uint64 timestamp) {
return out;
}
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
// Publishes encoded GraphDef through a gRPC debugger stream, in chunks,
// conforming to the gRPC message size limit.
Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def,
@@ -268,7 +265,7 @@ Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def,
}
return Status::OK();
}
-#endif
+#endif // #ifndef PLATFORM_WINDOWS
} // namespace
@@ -393,7 +390,7 @@ Status DebugIO::PublishDebugMetadata(
Status status;
for (const string& url : debug_urls) {
if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) {
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
Event grpc_event;
// Determine the path (if any) in the grpc:// URL, and add it as a field
@@ -411,7 +408,7 @@ Status DebugIO::PublishDebugMetadata(
status.Update(
DebugGrpcIO::SendEventProtoThroughGrpcStream(grpc_event, url));
#else
- GRPC_OSS_UNIMPLEMENTED_ERROR;
+ GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
} else if (str_util::Lowercase(url).find(kFileURLScheme) == 0) {
const string dump_root_dir = url.substr(strlen(kFileURLScheme));
@@ -450,7 +447,7 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
fail_statuses.push_back(s);
}
} else if (str_util::Lowercase(url).find(kGrpcURLScheme) == 0) {
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
Status s = DebugGrpcIO::SendTensorThroughGrpcStream(
debug_node_key, tensor, wall_time_us, url, gated_grpc);
@@ -459,7 +456,7 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
fail_statuses.push_back(s);
}
#else
- GRPC_OSS_UNIMPLEMENTED_ERROR;
+ GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
} else {
return Status(error::UNAVAILABLE,
@@ -519,11 +516,11 @@ Status DebugIO::PublishGraph(const Graph& graph, const string& device_name,
status.Update(
DebugFileIO::DumpEventProtoToFile(event, dump_root_dir, file_name));
} else if (debug_url.find(kGrpcURLScheme) == 0) {
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
status.Update(PublishEncodedGraphDefInChunks(buf, device_name, now_micros,
debug_url));
#else
- GRPC_OSS_UNIMPLEMENTED_ERROR;
+ GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
}
}
@@ -534,7 +531,7 @@ Status DebugIO::PublishGraph(const Graph& graph, const string& device_name,
// static
bool DebugIO::IsCopyNodeGateOpen(
const std::vector<DebugWatchAndURLSpec>& specs) {
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
for (const DebugWatchAndURLSpec& spec : specs) {
if (!spec.gated_grpc || spec.url.compare(0, strlen(DebugIO::kGrpcURLScheme),
DebugIO::kGrpcURLScheme)) {
@@ -554,7 +551,7 @@ bool DebugIO::IsCopyNodeGateOpen(
// static
bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
const std::vector<string>& debug_urls) {
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
for (const string& debug_url : debug_urls) {
if (debug_url.compare(0, strlen(DebugIO::kGrpcURLScheme),
DebugIO::kGrpcURLScheme)) {
@@ -574,7 +571,7 @@ bool DebugIO::IsDebugNodeGateOpen(const string& watch_key,
// static
bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
const string& debug_url) {
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
if (debug_url.find(kGrpcURLScheme) != 0) {
return true;
} else {
@@ -588,10 +585,10 @@ bool DebugIO::IsDebugURLGateOpen(const string& watch_key,
// static
Status DebugIO::CloseDebugURL(const string& debug_url) {
if (debug_url.find(DebugIO::kGrpcURLScheme) == 0) {
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
return DebugGrpcIO::CloseGrpcStream(debug_url);
#else
- GRPC_OSS_UNIMPLEMENTED_ERROR;
+ GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
} else {
// No-op for non-gRPC URLs.
@@ -703,7 +700,7 @@ Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
}
}
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr)
: server_stream_addr_(server_stream_addr),
url_(strings::StrCat(DebugIO::kGrpcURLScheme, server_stream_addr)) {}
@@ -926,6 +923,6 @@ void DebugGrpcIO::CreateEmptyEnabledSet(const string& grpc_debug_url) {
}
}
-#endif // #if defined(PLATFORM_GOOGLE)
+#endif // #ifndef PLATFORM_WINDOWS
} // namespace tensorflow
diff --git a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h
index 98effed425..caf9f5341d 100644
--- a/tensorflow/core/debug/debug_io_utils.h
+++ b/tensorflow/core/debug/debug_io_utils.h
@@ -221,7 +221,7 @@ class DebugFileIO {
// TODO(cais): Support grpc:// debug URLs in open source once Python grpc
// genrule becomes available. See b/23796275.
-#if defined(PLATFORM_GOOGLE)
+#ifndef PLATFORM_WINDOWS
#include "tensorflow/core/debug/debug_service.grpc.pb.h"
namespace tensorflow {
@@ -345,6 +345,6 @@ class DebugGrpcIO {
};
} // namespace tensorflow
-#endif // #if defined(PLATFORM_GOOGLE)
+#endif // #ifndef PLATFORM_WINDOWS
#endif // TENSORFLOW_DEBUG_IO_UTILS_H_
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 1e2d026a79..d12ad8a04f 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -75,7 +75,8 @@ def tf_proto_library_py(name, srcs=[], protodeps=[], deps=[], visibility=[],
def tf_proto_library(name, srcs = [], has_services = None,
protodeps = [], visibility = [], testonly = 0,
cc_libs = [],
- cc_api_version = 2, go_api_version = 2,
+ cc_api_version = 2, cc_grpc_version = None,
+ go_api_version = 2,
j2objc_api_version = 1,
java_api_version = 2, py_api_version = 2,
js_api_version = 2, js_codegen = "jspb"):
@@ -84,6 +85,7 @@ def tf_proto_library(name, srcs = [], has_services = None,
name = name,
srcs = srcs,
protodeps = protodeps,
+ cc_grpc_version = cc_grpc_version,
cc_libs = cc_libs,
testonly = testonly,
visibility = visibility,
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index d7ca36b12b..1c66413c05 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -6,7 +6,7 @@
# ":debug_py": Public Python methods and classes of tfdbg.
# For API documentation, see https://www.tensorflow.org/api_docs/python/tfdbg
# For a user interface walkthrough, see https://www.tensorflow.org/programmers_guide/debugger
-# (Omitted internal-only public target)
+# ":grpc_debug_server": Server interface for grpc:// debug URLs.
package(
default_visibility = ["//tensorflow:internal"],
@@ -34,6 +34,7 @@ py_library(
":debug_data",
":debug_gradients",
":debug_utils",
+ ":grpc_debug_server",
":hooks",
":local_cli_wrapper",
],
@@ -44,6 +45,7 @@ py_library(
name = "debug_pip",
deps = [
":debug_py",
+ ":grpc_debug_test_server",
":offline_analyzer",
":session_debug_testlib",
] + if_not_windows([
@@ -577,6 +579,41 @@ py_library(
],
)
+py_library(
+ name = "debug_service_pb2_grpc",
+ srcs = ["lib/debug_service_pb2_grpc.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core/debug:debug_service_proto_py",
+ ],
+)
+
+py_library(
+ name = "grpc_debug_server",
+ srcs = ["lib/grpc_debug_server.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":debug_data",
+ ":debug_service_pb2_grpc",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
+ name = "grpc_debug_test_server",
+ srcs = ["lib/grpc_debug_test_server.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":debug_data",
+ ":debug_utils",
+ ":grpc_debug_server",
+ "//tensorflow/python:client",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:variables",
+ ],
+)
+
cuda_py_test(
name = "session_debug_file_test",
size = "small",
@@ -743,6 +780,58 @@ cuda_py_test(
],
)
+cuda_py_test(
+ name = "session_debug_grpc_test",
+ size = "medium",
+ srcs = ["lib/session_debug_grpc_test.py"],
+ additional_deps = [
+ ":debug_data",
+ ":debug_utils",
+ ":dumping_wrapper",
+ ":grpc_debug_test_server",
+ ":grpc_wrapper",
+ ":hooks",
+ ":session_debug_testlib",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:variables",
+ ],
+ tags = [
+ "nomac", # TODO(cais): Install of futures and grpcio on all macs.
+ "notsan",
+ ],
+)
+
+# TODO(cais): Run the test in OSS, perhaps through a sh_test.
+cuda_py_test(
+ name = "dist_session_debug_grpc_test",
+ size = "medium",
+ srcs = ["lib/dist_session_debug_grpc_test.py"],
+ additional_deps = [
+ ":debug_data",
+ ":debug_utils",
+ ":dumping_wrapper",
+ ":grpc_debug_test_server",
+ ":grpc_wrapper",
+ ":hooks",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:variables",
+ ],
+ data = ["//tensorflow/tools/dist_test/server:grpc_tensorflow_server"],
+ tags = [
+ "no_oss", # b/62956105: port conflicts + Incompatible with bazel_pip.
+ "notsan",
+ ],
+)
+
py_test(
name = "dumping_wrapper_test",
size = "small",
diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py
index 2da386bb0f..3335657a61 100644
--- a/tensorflow/python/debug/lib/debug_data.py
+++ b/tensorflow/python/debug/lib/debug_data.py
@@ -34,6 +34,7 @@ from tensorflow.core.util import event_pb2
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import gfile
+from tensorflow.python.util import compat
# TODO(cais): Tie these string constants in with C++?
@@ -362,7 +363,7 @@ def extract_core_metadata_from_event_proto(event):
def device_name_to_device_path(device_name):
"""Convert device name to device path."""
- device_name_items = device_name.split("/")
+ device_name_items = compat.as_text(device_name).split("/")
device_name_items = [item.replace(":", "_") for item in device_name_items]
return METADATA_FILE_PREFIX + DEVICE_TAG + ",".join(device_name_items)
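
The compat.as_text() call above makes bytes and unicode device names map to the same dump path. A sketch of the expected behavior (the constant values are assumed from their definitions elsewhere in debug_data.py):

    from tensorflow.python.debug.lib import debug_data

    # A bytes device name is now accepted as well as a unicode one.
    path = debug_data.device_name_to_device_path(
        b"/job:worker/replica:0/task:0/cpu:0")
    # Assuming METADATA_FILE_PREFIX == "_tfdbg_" and DEVICE_TAG == "device_":
    # path == "_tfdbg_device_,job_worker,replica_0,task_0,cpu_0"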
diff --git a/tensorflow/python/debug/lib/debug_service_pb2_grpc.py b/tensorflow/python/debug/lib/debug_service_pb2_grpc.py
new file mode 100755
index 0000000000..98adc3284b
--- /dev/null
+++ b/tensorflow/python/debug/lib/debug_service_pb2_grpc.py
@@ -0,0 +1,76 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+#
+# Do not use pylint on generated code.
+# pylint: disable=missing-docstring,g-short-docstring-punctuation,g-no-space-after-docstring-summary,invalid-name,line-too-long,unused-argument,g-doc-args
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import grpc
+
+from tensorflow.core.debug import debug_service_pb2 as tensorflow_dot_core_dot_debug_dot_debug__service__pb2
+from tensorflow.core.util import event_pb2 as tensorflow_dot_core_dot_util_dot_event__pb2
+
+
+class EventListenerStub(object):
+ """EventListener: Receives Event protos, e.g., from debugged TensorFlow
+ runtime(s).
+ """
+
+ def __init__(self, channel):
+ """Constructor.
+
+ Args:
+ channel: A grpc.Channel.
+ """
+ self.SendEvents = channel.stream_stream(
+ '/tensorflow.EventListener/SendEvents',
+ request_serializer=tensorflow_dot_core_dot_util_dot_event__pb2.Event.SerializeToString,
+ response_deserializer=tensorflow_dot_core_dot_debug_dot_debug__service__pb2.EventReply.FromString,
+ )
+
+
+class EventListenerServicer(object):
+ """EventListener: Receives Event protos, e.g., from debugged TensorFlow
+ runtime(s).
+ """
+
+ def SendEvents(self, request_iterator, context):
+ """Client(s) can use this RPC method to send the EventListener Event protos.
+ The Event protos can hold information such as:
+ 1) intermediate tensors from a debugged graph being executed, which can
+ be sent from DebugIdentity ops configured with grpc URLs.
+ 2) GraphDefs of partition graphs, which can be sent from special debug
+ ops that get executed immediately after the beginning of the graph
+ execution.
+ """
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details('Method not implemented!')
+ raise NotImplementedError('Method not implemented!')
+
+
+def add_EventListenerServicer_to_server(servicer, server):
+ rpc_method_handlers = {
+ 'SendEvents': grpc.stream_stream_rpc_method_handler(
+ servicer.SendEvents,
+ request_deserializer=tensorflow_dot_core_dot_util_dot_event__pb2.Event.FromString,
+ response_serializer=tensorflow_dot_core_dot_debug_dot_debug__service__pb2.EventReply.SerializeToString,
+ ),
+ }
+ generic_handler = grpc.method_handlers_generic_handler(
+ 'tensorflow.EventListener', rpc_method_handlers)
+ server.add_generic_rpc_handlers((generic_handler,))
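
A sketch of driving the generated stub directly, for readers unfamiliar with grpc-python stream-stream calls (the server address is illustrative):

    import grpc
    from tensorflow.core.util import event_pb2
    from tensorflow.python.debug.lib import debug_service_pb2_grpc

    channel = grpc.insecure_channel("localhost:6064")
    stub = debug_service_pb2_grpc.EventListenerStub(channel)

    def event_stream():
      # Yield Event protos; real debug ops attach tensors and graphs here.
      yield event_pb2.Event(wall_time=0.0)

    # SendEvents is bidirectional: the server streams back EventReply protos,
    # e.g., debug-op state changes for gated watch keys.
    for reply in stub.SendEvents(event_stream()):
      print(reply)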
diff --git a/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py b/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py
new file mode 100644
index 0000000000..1e1fbf39d4
--- /dev/null
+++ b/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py
@@ -0,0 +1,230 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for debugger functionalities in tf.Session with grpc:// URLs.
+
+This test focuses on grpc:// debugging of distributed (gRPC) sessions.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import subprocess
+import sys
+import time
+
+import portpicker
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.debug.lib import debug_utils
+from tensorflow.python.debug.lib import grpc_debug_test_server
+from tensorflow.python.debug.wrappers import framework
+from tensorflow.python.debug.wrappers import grpc_wrapper
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+
+class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
+ """Test the debugging of distributed sessions."""
+
+ PER_PROC_GPU_MEMORY_FRACTION = 0.1
+ POLLING_INTERVAL_SEC = 0.025
+
+ @classmethod
+ def setUpClass(cls):
+ gpu_memory_fraction_opt = (
+ "--gpu_memory_fraction=%f" % cls.PER_PROC_GPU_MEMORY_FRACTION)
+
+ worker_port = portpicker.pick_unused_port()
+ cluster_spec = "worker|localhost:%d" % worker_port
+ tf_logging.info("cluster_spec: %s", cluster_spec)
+
+ server_bin = test.test_src_dir_path(
+ "tools/dist_test/server/grpc_tensorflow_server")
+
+ cls.server_target = "grpc://localhost:%d" % worker_port
+
+ cls.server_procs = {}
+ cls.server_procs["worker"] = subprocess.Popen(
+ [
+ server_bin,
+ "--cluster_spec=%s" % cluster_spec,
+ "--job_name=worker",
+ "--task_id=0",
+ gpu_memory_fraction_opt,
+ ],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+
+ # Start debug server in-process, on separate thread.
+ (cls.debug_server_port, cls.debug_server_url, _, cls.debug_server_thread,
+ cls.debug_server
+ ) = grpc_debug_test_server.start_server_on_separate_thread(
+ dump_to_filesystem=False)
+ tf_logging.info("debug server url: %s", cls.debug_server_url)
+
+ cls.session_config = config_pb2.ConfigProto(
+ gpu_options=config_pb2.GPUOptions(
+ per_process_gpu_memory_fraction=cls.PER_PROC_GPU_MEMORY_FRACTION))
+
+ @classmethod
+ def tearDownClass(cls):
+ for key in cls.server_procs:
+ cls.server_procs[key].terminate()
+ cls.debug_server.stop_server().wait()
+ cls.debug_server_thread.join()
+
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ self.debug_server.clear_data()
+
+ def _pollingAssertDebugTensorValuesAllClose(self, expected_values,
+ debug_tensor_name):
+ """Poll debug_server till tensor appears and matches expected values."""
+ while (debug_tensor_name not in self.debug_server.debug_tensor_values or
+ len(self.debug_server.debug_tensor_values[debug_tensor_name]) <
+ len(expected_values)):
+ time.sleep(self.POLLING_INTERVAL_SEC)
+ self.assertAllClose(
+ expected_values,
+ self.debug_server.debug_tensor_values[debug_tensor_name])
+
+ def _createGraph(self):
+ """Create graph for testing.
+
+ Returns:
+ Python Graph object.
+ """
+ with ops.Graph().as_default() as graph:
+ with ops.device("/job:worker/task:0/cpu:0"):
+ self.a = variables.Variable(10.0, name="a")
+ self.b = variables.Variable(100.0, name="b")
+ self.inc_a = state_ops.assign_add(self.a, 2.0, name="inc_a")
+ self.dec_b = state_ops.assign_add(self.b, -5.0, name="dec_b")
+ self.p = math_ops.multiply(self.inc_a, self.dec_b, name="p")
+ self.q = math_ops.negative(self.p, name="q")
+ return graph
+
+ def testDistributedRunWithGatedGrpcCommunicatesWithDebugServerCorrectly(self):
+ graph = self._createGraph()
+ with session.Session(
+ config=self.session_config, graph=graph,
+ target=self.server_target) as sess:
+ sess.run(self.a.initializer)
+ sess.run(self.b.initializer)
+
+ run_options = config_pb2.RunOptions()
+ debug_utils.watch_graph(
+ run_options,
+ sess.graph,
+ node_name_regex_whitelist=r"a",
+ debug_ops=["DebugIdentity"],
+ debug_urls=[self.debug_server_url])
+
+ # Test gated_grpc for an op located on the worker, i.e., on the same
+ # host as where MasterSession is.
+ # TODO(cais): gRPC gating of debug ops does not work on partition graphs
+ # not located on MasterSession hosts (e.g., parameter servers) yet. Make
+ # it work.
+ debug_utils.watch_graph(
+ run_options,
+ sess.graph,
+ node_name_regex_whitelist=r"p",
+ debug_ops=["DebugIdentity(gated_grpc=True)"],
+ debug_urls=[self.debug_server_url])
+
+ for i in xrange(4):
+ # N.B.: These requests will be fulfilled not in this debugged
+ # Session.run() invocation, but in the next one.
+ if i % 2 == 0:
+ self.debug_server.request_watch("p", 0, "DebugIdentity")
+ else:
+ self.debug_server.request_unwatch("p", 0, "DebugIdentity")
+
+ expected_p = (10.0 + 2.0 * (i + 1)) * (100.0 - 5.0 * (i + 1))
+ self.assertAllClose(-expected_p, sess.run(self.q, options=run_options))
+
+ self.assertEqual(1, len(self.debug_server.core_metadata_json_strings))
+ core_metadata = json.loads(
+ self.debug_server.core_metadata_json_strings[0])
+ self.assertEqual([], core_metadata["input_names"])
+ self.assertEqual(["q:0"], core_metadata["output_names"])
+ self.assertEqual(i, core_metadata["executor_step_index"])
+
+ if i == 0:
+ self.assertEqual(1, len(self.debug_server.partition_graph_defs))
+
+ # Tensor "a" is from a PS. It may take longer to arrive due to the fact
+ # that the stream connection between the PS and the debug server is
+ # persistent and not torn down at the end of each Session.run()
+ self._pollingAssertDebugTensorValuesAllClose([10.0 + 2.0 * i],
+ "a:0:DebugIdentity")
+
+ # Due to the gRPC gating of the debug op for "p", the debug tensor
+ # should be available on odd-indexed runs.
+ if i % 2 == 0:
+ self.assertNotIn("p:0:DebugIdentity",
+ self.debug_server.debug_tensor_values)
+ else:
+ self.assertAllClose(
+ [expected_p],
+ self.debug_server.debug_tensor_values["p:0:DebugIdentity"])
+
+ self.assertNotIn("b:0:DebugIdentity",
+ self.debug_server.debug_tensor_values)
+ self.debug_server.clear_data()
+
+ def testDistributedRunWithGrpcDebugWrapperWorks(self):
+ graph = self._createGraph()
+ with session.Session(
+ config=self.session_config, graph=graph,
+ target=self.server_target) as sess:
+ sess.run(self.a.initializer)
+ sess.run(self.b.initializer)
+
+ def watch_fn(feeds, fetch_keys):
+ del feeds, fetch_keys
+ return framework.WatchOptions(
+ debug_ops=["DebugIdentity"],
+ node_name_regex_whitelist=r"p")
+ sess = grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+
+ for i in xrange(4):
+ expected_p = (10.0 + 2.0 * (i + 1)) * (100.0 - 5.0 * (i + 1))
+ self.assertAllClose(-expected_p, sess.run(self.q))
+
+ if i == 0:
+ self.assertEqual(1, len(self.debug_server.partition_graph_defs))
+
+ self.assertAllClose(
+ [expected_p],
+ self.debug_server.debug_tensor_values["p:0:DebugIdentity"])
+ self.assertNotIn("b:0:DebugIdentity",
+ self.debug_server.debug_tensor_values)
+ self.debug_server.clear_data()
+
+
+if __name__ == "__main__":
+ googletest.main()
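
Condensed, the gated-grpc round trip that the test above exercises looks like this (sess, q, server, and server_url are stand-ins for the objects the test constructs):

    from tensorflow.core.protobuf import config_pb2
    from tensorflow.python.debug.lib import debug_utils

    run_options = config_pb2.RunOptions()
    debug_utils.watch_graph(
        run_options,
        sess.graph,
        node_name_regex_whitelist=r"p",
        debug_ops=["DebugIdentity(gated_grpc=True)"],  # gated at the server
        debug_urls=[server_url])

    server.request_watch("p", 0, "DebugIdentity")    # opens the gate...
    sess.run(q, options=run_options)                 # ...for this next run
    server.request_unwatch("p", 0, "DebugIdentity")  # closes it again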
diff --git a/tensorflow/python/debug/lib/grpc_debug_server.py b/tensorflow/python/debug/lib/grpc_debug_server.py
new file mode 100644
index 0000000000..181b437695
--- /dev/null
+++ b/tensorflow/python/debug/lib/grpc_debug_server.py
@@ -0,0 +1,395 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""gRPC debug server in Python."""
+# pylint: disable=g-bad-import-order
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import json
+import threading
+import time
+
+from concurrent import futures
+import grpc
+from six.moves import queue
+
+from tensorflow.core.debug import debug_service_pb2
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.lib import debug_service_pb2_grpc
+
+DebugWatch = collections.namedtuple("DebugWatch",
+ ["node_name", "output_slot", "debug_op"])
+
+
+def _watch_key_event_reply(to_enable, node_name, output_slot, debug_op):
+ """Make EventReply proto to represent a request to watch/unwatch a debug op.
+
+ Args:
+ to_enable: (`bool`) whether the request is to enable the watch key.
+ node_name: (`str`) name of the node.
+ output_slot: (`int`) output slot of the tensor.
+ debug_op: (`str`) the debug op attached to node_name:output_slot tensor to
+ watch or unwatch.
+
+ Returns:
+ An EventReply proto.
+ """
+ event_reply = debug_service_pb2.EventReply()
+ state_change = event_reply.debug_op_state_changes.add()
+ state_change.change = (
+ debug_service_pb2.EventReply.DebugOpStateChange.ENABLE
+ if to_enable else debug_service_pb2.EventReply.DebugOpStateChange.DISABLE)
+ state_change.node_name = node_name
+ state_change.output_slot = output_slot
+ state_change.debug_op = debug_op
+ return event_reply
+
+
+class EventListenerBaseStreamHandler(object):
+ """Per-stream handler of EventListener gRPC streams."""
+
+ def __init__(self):
+ """Constructor of EventListenerStreamHandler."""
+ raise NotImplementedError(
+ "__init__() is not implemented in the base stream handler class")
+
+ def on_core_metadata_event(self, event):
+ """Callback for core metadata.
+
+ Args:
+ event: The Event proto that carries a JSON string in its
+ `log_message.message` field.
+ """
+ raise NotImplementedError(
+ "on_core_metadata_event() is not implemented in the base servicer "
+ "class")
+
+ def on_graph_def(self, graph_def, device_name, wall_time):
+ """Callback for Event proto received through the gRPC stream.
+
+ This Event proto carries a GraphDef, encoded as bytes, in its graph_def
+ field.
+
+ Args:
+ graph_def: A GraphDef object.
+ device_name: Name of the device on which the graph was created.
+ wall_time: An epoch timestamp (in microseconds) for the graph.
+ """
+ raise NotImplementedError(
+ "on_graph_def() is not implemented in the base servicer class")
+
+ def on_value_event(self, event):
+ """Callback for Event proto received through the gRPC stream.
+
+ This Event proto carries a Tensor in its summary.value[0] field.
+
+ Args:
+ event: The Event proto from the stream to be processed.
+ """
+ raise NotImplementedError(
+ "on_value_event() is not implemented in the base servicer class")
+
+
+class EventListenerBaseServicer(debug_service_pb2_grpc.EventListenerServicer):
+ """Base Python class for gRPC debug server."""
+
+ def __init__(self, server_port, stream_handler_class):
+ """Constructor.
+
+ Args:
+ server_port: (int) Port number to bind to.
+ stream_handler_class: A subclass of the base class
+ `EventListenerBaseStreamHandler` that will be used to construct
+ stream handler objects during `SendEvents` calls.
+ """
+
+ self._server_port = server_port
+ self._stream_handler_class = stream_handler_class
+
+ self._server_lock = threading.Lock()
+ self._server_started = False
+ self._stop_requested = False
+
+ self._event_reply_queue = queue.Queue()
+ self._gated_grpc_debug_watches = set()
+
+ def SendEvents(self, request_iterator, context):
+ """Implementation of the SendEvents service method.
+
+ This method receives streams of Event protos from the client, and processes
+ them in ways specified in the on_event() callback. The stream is
+ bi-directional, but currently only the client-to-server stream (i.e., the
+ stream from the debug ops to the server) is used.
+
+ Args:
+ request_iterator: The incoming stream of Event protos.
+ context: Server context.
+
+ Raises:
+ ValueError: If more than one core metadata event is received.
+
+ Yields:
+ An empty stream of responses.
+ """
+ core_metadata_count = 0
+
+ # A map from GraphDef hash to a list of received chunks.
+ graph_def_chunks = {}
+ tensor_chunks = {}
+
+ stream_handler = None
+ for event in request_iterator:
+ if not stream_handler:
+ stream_handler = self._stream_handler_class()
+
+ if event.graph_def:
+ maybe_graph_def, maybe_device_name, maybe_wall_time = (
+ self._process_encoded_graph_def_in_chunks(event, graph_def_chunks))
+ if maybe_graph_def:
+ stream_handler.on_graph_def(
+ maybe_graph_def, maybe_device_name, maybe_wall_time)
+ elif event.log_message.message:
+ core_metadata_count += 1
+ if core_metadata_count > 1:
+ raise ValueError(
+ "Expected one core metadata event; received multiple")
+ stream_handler.on_core_metadata_event(event)
+ elif event.summary and event.summary.value:
+ maybe_tensor_event = self._process_tensor_event_in_chunks(
+ event, tensor_chunks)
+ if maybe_tensor_event:
+ stream_handler.on_value_event(maybe_tensor_event)
+
+ # The server writes EventReply messages, if any.
+ while not self._event_reply_queue.empty():
+ yield self._event_reply_queue.get()
+
+ def _process_tensor_event_in_chunks(self, event, tensor_chunks):
+ """Possibly reassemble event chunks.
+
+ Due to gRPC's message size limit, a large tensor can be encapsulated in
+ multiple Event proto chunks to be sent through the debugger stream. This
+ method keeps track of the chunks that have arrived, reassembles all chunks
+ corresponding to a tensor once they have all arrived, and returns the
+ reassembled Event proto.
+
+ Args:
+ event: The single Event proto that has arrived.
+ tensor_chunks: A dict used to keep track of the Event protos that have
+ arrived but haven't been reassembled.
+
+ Returns:
+ If all Event protos corresponding to a tensor have arrived, returns the
+ reassembled Event proto. Otherwise, returns None.
+ """
+
+ value = event.summary.value[0]
+ debugger_plugin_metadata = json.loads(
+ value.metadata.plugin_data[0].content)
+ device_name = debugger_plugin_metadata["device"]
+ num_chunks = debugger_plugin_metadata["numChunks"]
+ chunk_index = debugger_plugin_metadata["chunkIndex"]
+
+ if num_chunks <= 1:
+ return event
+
+ debug_node_name = value.node_name
+ timestamp = int(event.wall_time)
+ tensor_key = "%s_%s_%d" % (device_name, debug_node_name, timestamp)
+
+ if tensor_key not in tensor_chunks:
+ tensor_chunks[tensor_key] = [None] * num_chunks
+
+ chunks = tensor_chunks[tensor_key]
+ if value.tensor.tensor_content:
+ chunks[chunk_index] = value.tensor
+ elif value.tensor.string_val:
+ chunks[chunk_index] = event
+
+ if None not in chunks:
+ if value.tensor.tensor_content:
+ event.summary.value[0].tensor.tensor_content = b"".join(
+ chunk.tensor_content for chunk in chunks)
+ del tensor_chunks[tensor_key]
+ return event
+ elif value.tensor.string_val:
+ merged_event = chunks[0]
+ for chunk in chunks[1:]:
+ merged_event.summary.value[0].tensor.string_val.extend(
+ list(chunk.summary.value[0].tensor.string_val))
+ return merged_event
+
+ def _process_encoded_graph_def_in_chunks(self,
+ event,
+ graph_def_chunks):
+ """Process an Event proto containing a chunk of encoded GraphDef.
+
+ Args:
+ event: the Event proto containing the chunk of encoded GraphDef.
+ graph_def_chunks: A dict mapping keys for GraphDefs (i.e.,
+ "<graph_def_hash>,<device_name>,<wall_time>") to a list of chunks of
+ encoded GraphDefs.
+
+ Returns:
+ If all chunks of the GraphDef have arrived,
+ return decoded GraphDef proto, device name, wall_time.
+ Otherwise,
+ return None, None, None.
+ """
+ graph_def = graph_pb2.GraphDef()
+ index_bar_0 = event.graph_def.find(b"|")
+ index_bar_1 = event.graph_def.find(b"|", index_bar_0 + 1)
+ index_bar_2 = event.graph_def.find(b"|", index_bar_1 + 1)
+ graph_def_hash_device_timestamp = event.graph_def[:index_bar_0]
+ chunk_index = int(event.graph_def[index_bar_0 + 1 : index_bar_1])
+ num_chunks = int(event.graph_def[index_bar_1 + 1 : index_bar_2])
+ if graph_def_hash_device_timestamp not in graph_def_chunks:
+ graph_def_chunks[graph_def_hash_device_timestamp] = [None] * num_chunks
+ graph_def_chunks[graph_def_hash_device_timestamp][
+ chunk_index] = event.graph_def[index_bar_2 + 1:]
+ if all(graph_def_chunks[graph_def_hash_device_timestamp]):
+ device_name = graph_def_hash_device_timestamp.split(b",")[1]
+ wall_time = int(graph_def_hash_device_timestamp.split(b",")[2])
+ graph_def.ParseFromString(
+ b"".join(graph_def_chunks[graph_def_hash_device_timestamp]))
+ del graph_def_chunks[graph_def_hash_device_timestamp]
+ self._process_graph_def(graph_def)
+ return graph_def, device_name, wall_time
+ else:
+ return None, None, None
+
+ def _process_graph_def(self, graph_def):
+ for node_def in graph_def.node:
+ if (debug_data.is_debug_node(node_def.name) and
+ node_def.attr["gated_grpc"].b):
+ node_name, output_slot, _, debug_op = (
+ debug_data.parse_debug_node_name(node_def.name))
+ self._gated_grpc_debug_watches.add(
+ DebugWatch(node_name, output_slot, debug_op))
+
+ def run_server(self):
+ """Start running the server.
+
+ Blocks until `stop_server` is invoked.
+
+ Raises:
+ ValueError: If server stop has already been requested, or if the server
+ has already started running.
+ """
+ self._server_lock.acquire()
+ try:
+ if self._stop_requested:
+ raise ValueError("Server has already stopped")
+ if self._server_started:
+ raise ValueError("Server has already started running")
+
+ self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ debug_service_pb2_grpc.add_EventListenerServicer_to_server(self,
+ self.server)
+ self.server.add_insecure_port("[::]:%d" % self._server_port)
+ self.server.start()
+ self._server_started = True
+ finally:
+ self._server_lock.release()
+
+ while not self._stop_requested:
+ time.sleep(1.0)
+
+ def stop_server(self, grace=1.0):
+ """Request server stopping.
+
+ Once stopped, the server cannot be stopped or started again. This method is
+ non-blocking. Call `wait()` on the returned event to block until the server
+ has completely stopped.
+
+ Args:
+ grace: Grace period in seconds to be used when calling `server.stop()`.
+
+ Raises:
+ ValueError: If server stop has already been requested, or if the server
+ has not started running yet.
+
+ Returns:
+ A threading.Event that will be set when the server has completely stopped.
+ """
+ self._server_lock.acquire()
+ try:
+ if not self._server_started:
+ raise ValueError("Server has not started running")
+ if self._stop_requested:
+ raise ValueError("Server has already stopped")
+
+ self._stop_requested = True
+ return self.server.stop(grace=grace)
+ finally:
+ self._server_lock.release()
+
+ def request_watch(self, node_name, output_slot, debug_op):
+ """Request enabling a debug tensor watch.
+
+ This will let the server send an EventReply to the client side
+ (i.e., the debugged TensorFlow runtime process) to request adding a watch
+ key (i.e., <node_name>:<output_slot>:<debug_op>) to the list of enabled
+ watch keys. The list applies only to debug ops with the attribute
+ gated_grpc=True.
+
+ The request will take effect on the next debugged `Session.run()` call.
+
+ To disable the watch, use `request_unwatch()`.
+
+ Args:
+ node_name: (`str`) name of the node that the to-be-watched tensor belongs
+ to, e.g., "hidden/Weights".
+ output_slot: (`int`) output slot index of the tensor to watch.
+ debug_op: (`str`) name of the debug op to enable. This should not include
+ any attribute substrings.
+ """
+ self._event_reply_queue.put(
+ _watch_key_event_reply(True, node_name, output_slot, debug_op))
+
+ def request_unwatch(self, node_name, output_slot, debug_op):
+ """Request disabling a debug tensor watch.
+
+ The request will take effect on the next debugged `Session.run()` call.
+
+ This is the opposite of `request_watch()`.
+
+ Args:
+      node_name: (`str`) name of the node that the tensor to be unwatched
+        belongs to, e.g., "hidden/Weights".
+      output_slot: (`int`) output slot index of the tensor to unwatch.
+      debug_op: (`str`) name of the debug op to disable. This should not include
+ any attribute substrings.
+ """
+ self._event_reply_queue.put(
+ _watch_key_event_reply(False, node_name, output_slot, debug_op))
+
+ def gated_grpc_debug_watches(self):
+ """Get the list of debug watches with attribute gated_grpc=True.
+
+    Since the server receives `GraphDef` from the debugged runtime, it can
+    only return the debug watches it has seen so far.
+
+ Returns:
+ A `list` of `DebugWatch` `namedtuples` representing the debug watches with
+ gated_grpc=True. Each `namedtuple` element has the attributes:
+ `node_name` as a `str`,
+ `output_slot` as an `int`,
+ `debug_op` as a `str`.
+ """
+ return list(self._gated_grpc_debug_watches)
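
A minimal usage sketch of the servicer API added above (the subclass name
MyEventListenerServicer is hypothetical; run_server, stop_server,
request_watch and request_unwatch are the methods defined in this file):

    import threading

    servicer = MyEventListenerServicer(server_port=6064)  # hypothetical subclass
    server_thread = threading.Thread(target=servicer.run_server)
    server_thread.start()

    # Enable a gated watch; it takes effect on the next debugged
    # Session.run() call and applies only to debug ops with gated_grpc=True.
    servicer.request_watch("hidden/Weights", 0, "DebugIdentity")
    # ... debugged Session.run() calls happen here ...
    servicer.request_unwatch("hidden/Weights", 0, "DebugIdentity")

    # stop_server() is non-blocking; wait() blocks until full shutdown.
    servicer.stop_server(grace=1.0).wait()
    server_thread.join()
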
diff --git a/tensorflow/python/debug/lib/grpc_debug_test_server.py b/tensorflow/python/debug/lib/grpc_debug_test_server.py
new file mode 100644
index 0000000000..32751e0f29
--- /dev/null
+++ b/tensorflow/python/debug/lib/grpc_debug_test_server.py
@@ -0,0 +1,336 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""GRPC debug server for testing."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import errno
+import functools
+import hashlib
+import json
+import os
+import re
+import shutil
+import tempfile
+import threading
+import time
+
+import portpicker
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.util import event_pb2
+from tensorflow.python.client import session
+from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.lib import debug_utils
+from tensorflow.python.debug.lib import grpc_debug_server
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import variables
+
+
+def _get_dump_file_path(dump_root, device_name, debug_node_name):
+ """Get the file path of the dump file for a debug node.
+
+ Args:
+ dump_root: (str) Root dump directory.
+ device_name: (str) Name of the device that the debug node resides on.
+ debug_node_name: (str) Name of the debug node, e.g.,
+ cross_entropy/Log:0:DebugIdentity.
+
+ Returns:
+ (str) Full path of the dump file.
+ """
+
+ dump_root = os.path.join(
+ dump_root, debug_data.device_name_to_device_path(device_name))
+ if "/" in debug_node_name:
+ dump_dir = os.path.join(dump_root, os.path.dirname(debug_node_name))
+ dump_file_name = re.sub(":", "_", os.path.basename(debug_node_name))
+ else:
+ dump_dir = dump_root
+ dump_file_name = re.sub(":", "_", debug_node_name)
+
+ now_microsec = int(round(time.time() * 1000 * 1000))
+ dump_file_name += "_%d" % now_microsec
+
+ return os.path.join(dump_dir, dump_file_name)
+
+
+class EventListenerTestStreamHandler(
+ grpc_debug_server.EventListenerBaseStreamHandler):
+ """Implementation of EventListenerBaseStreamHandler that dumps to file."""
+
+ def __init__(self, dump_dir, event_listener_servicer):
+ self._dump_dir = dump_dir
+ self._event_listener_servicer = event_listener_servicer
+ if self._dump_dir:
+ self._try_makedirs(self._dump_dir)
+
+ self._grpc_path = None
+ self._cached_graph_defs = []
+ self._cached_graph_def_device_names = []
+ self._cached_graph_def_wall_times = []
+
+ def on_core_metadata_event(self, event):
+ core_metadata = json.loads(event.log_message.message)
+
+ if not self._grpc_path:
+ grpc_path = core_metadata["grpc_path"]
+ if grpc_path:
+ if grpc_path.startswith("/"):
+ grpc_path = grpc_path[1:]
+ if self._dump_dir:
+ self._dump_dir = os.path.join(self._dump_dir, grpc_path)
+
+ # Write cached graph defs to filesystem.
+ for graph_def, device_name, wall_time in zip(
+ self._cached_graph_defs,
+ self._cached_graph_def_device_names,
+ self._cached_graph_def_wall_times):
+ self._write_graph_def(graph_def, device_name, wall_time)
+
+ if self._dump_dir:
+ self._write_core_metadata_event(event)
+ else:
+ self._event_listener_servicer.core_metadata_json_strings.append(
+ event.log_message.message)
+
+ def on_graph_def(self, graph_def, device_name, wall_time):
+ """Implementation of the tensor value-carrying Event proto callback.
+
+ Args:
+ graph_def: A GraphDef object.
+ device_name: Name of the device on which the graph was created.
+ wall_time: An epoch timestamp (in microseconds) for the graph.
+ """
+ if self._dump_dir:
+ if self._grpc_path:
+ self._write_graph_def(graph_def, device_name, wall_time)
+ else:
+ self._cached_graph_defs.append(graph_def)
+ self._cached_graph_def_device_names.append(device_name)
+ self._cached_graph_def_wall_times.append(wall_time)
+ else:
+ self._event_listener_servicer.partition_graph_defs.append(graph_def)
+
+ def on_value_event(self, event):
+ """Implementation of the tensor value-carrying Event proto callback.
+
+ Writes the Event proto to the file system for testing. The path written to
+ follows the same pattern as the file:// debug URLs of tfdbg, i.e., the
+ name scope of the op becomes the directory structure under the dump root
+ directory.
+
+ Args:
+ event: The Event proto carrying a tensor value.
+ """
+ if self._dump_dir:
+ self._write_value_event(event)
+ else:
+ value = event.summary.value[0]
+ self._event_listener_servicer.debug_tensor_values[value.node_name].append(
+ debug_data.load_tensor_from_event(event))
+
+ def _try_makedirs(self, dir_path):
+ if not os.path.isdir(dir_path):
+ try:
+ os.makedirs(dir_path)
+ except OSError as error:
+ if error.errno != errno.EEXIST:
+ raise
+
+ def _write_core_metadata_event(self, event):
+ core_metadata_path = os.path.join(
+ self._dump_dir,
+ debug_data.METADATA_FILE_PREFIX + debug_data.CORE_METADATA_TAG +
+ "_%d" % event.wall_time)
+ self._try_makedirs(self._dump_dir)
+ with open(core_metadata_path, "wb") as f:
+ f.write(event.SerializeToString())
+
+ def _write_graph_def(self, graph_def, device_name, wall_time):
+ encoded_graph_def = graph_def.SerializeToString()
+ graph_hash = int(hashlib.md5(encoded_graph_def).hexdigest(), 16)
+ event = event_pb2.Event(graph_def=encoded_graph_def, wall_time=wall_time)
+ graph_file_path = os.path.join(
+ self._dump_dir,
+ debug_data.device_name_to_device_path(device_name),
+ debug_data.METADATA_FILE_PREFIX + debug_data.GRAPH_FILE_TAG +
+ debug_data.HASH_TAG + "%d_%d" % (graph_hash, wall_time))
+ self._try_makedirs(os.path.dirname(graph_file_path))
+ with open(graph_file_path, "wb") as f:
+ f.write(event.SerializeToString())
+
+ def _write_value_event(self, event):
+ value = event.summary.value[0]
+
+ # Obtain the device name from the metadata.
+ summary_metadata = event.summary.value[0].metadata
+ if not summary_metadata.plugin_data:
+ raise ValueError("The value lacks plugin data.")
+    raw_content = summary_metadata.plugin_data[0].content
+    try:
+      content = json.loads(raw_content)
+    except ValueError as err:
+      raise ValueError(
+          "Could not parse content into JSON: %r, %r" % (raw_content, err))
+ device_name = content["device"]
+
+ dump_full_path = _get_dump_file_path(
+ self._dump_dir, device_name, value.node_name)
+ self._try_makedirs(os.path.dirname(dump_full_path))
+ with open(dump_full_path, "wb") as f:
+ f.write(event.SerializeToString())
+
+
+class EventListenerTestServicer(grpc_debug_server.EventListenerBaseServicer):
+ """An implementation of EventListenerBaseServicer for testing."""
+
+ def __init__(self, server_port, dump_dir):
+ """Constructor of EventListenerTestServicer.
+
+ Args:
+ server_port: (int) The server port number.
+ dump_dir: (str) The root directory to which the data files will be
+        dumped. If empty or None, the received debug data will not be dumped
+        to the file system but will be kept in memory instead.
+ """
+ self.core_metadata_json_strings = []
+ self.partition_graph_defs = []
+ self.debug_tensor_values = collections.defaultdict(list)
+
+ grpc_debug_server.EventListenerBaseServicer.__init__(
+ self, server_port,
+ functools.partial(EventListenerTestStreamHandler, dump_dir, self))
+
+ def clear_data(self):
+ self.core_metadata_json_strings = []
+ self.partition_graph_defs = []
+ self.debug_tensor_values = collections.defaultdict(list)
+
+
+def start_server_on_separate_thread(dump_to_filesystem=True,
+ server_start_delay_sec=0.0,
+ poll_server=False):
+ """Create a test gRPC debug server and run on a separate thread.
+
+ Args:
+ dump_to_filesystem: (bool) whether the debug server will dump debug data
+ to the filesystem.
+    server_start_delay_sec: (float) number of seconds by which to delay the
+      server startup.
+    poll_server: (bool) whether to poll the server until it responds
+      successfully after startup.
+
+ Returns:
+ server_port: (int) Port on which the server runs.
+ debug_server_url: (str) grpc:// URL to the server.
+ server_dump_dir: (str) The debug server's dump directory.
+ server_thread: The server Thread object.
+ server: The `EventListenerTestServicer` object.
+
+ Raises:
+    ValueError: If polling the server for its ready state does not succeed
+      within the maximum number of attempts.
+ """
+ server_port = portpicker.pick_unused_port()
+ debug_server_url = "grpc://localhost:%d" % server_port
+
+ server_dump_dir = tempfile.mkdtemp() if dump_to_filesystem else None
+ server = EventListenerTestServicer(server_port=server_port,
+ dump_dir=server_dump_dir)
+
+ def delay_then_run_server():
+ time.sleep(server_start_delay_sec)
+ server.run_server()
+ server_thread = threading.Thread(target=delay_then_run_server)
+ server_thread.start()
+
+ if poll_server:
+ if not _poll_server_till_success(
+ 50,
+ 0.2,
+ debug_server_url,
+ server_dump_dir,
+ server,
+ gpu_memory_fraction=0.1):
+ raise ValueError(
+ "Failed to start test gRPC debug server at port %d" % server_port)
+ server.clear_data()
+ return server_port, debug_server_url, server_dump_dir, server_thread, server
+
+
+def _poll_server_till_success(max_attempts,
+ sleep_per_poll_sec,
+ debug_server_url,
+ dump_dir,
+ server,
+ gpu_memory_fraction=1.0):
+ """Poll server until success or exceeding max polling count.
+
+ Args:
+    max_attempts: (int) Maximum number of polling attempts.
+ sleep_per_poll_sec: (float) How many seconds to sleep for after each
+ unsuccessful poll.
+ debug_server_url: (str) gRPC URL to the debug server.
+ dump_dir: (str) Dump directory to look for files in. If None, will directly
+ check data from the server object.
+ server: The server object.
+ gpu_memory_fraction: (float) Fraction of GPU memory to be
+ allocated for the Session used in server polling.
+
+ Returns:
+    (bool) Whether the polling succeeded within max_attempts attempts.
+ """
+ poll_count = 0
+
+ config = config_pb2.ConfigProto(gpu_options=config_pb2.GPUOptions(
+ per_process_gpu_memory_fraction=gpu_memory_fraction))
+ with session.Session(config=config) as sess:
+ for poll_count in range(max_attempts):
+ server.clear_data()
+ print("Polling: poll_count = %d" % poll_count)
+
+ x_init_name = "x_init_%d" % poll_count
+ x_init = constant_op.constant([42.0], shape=[1], name=x_init_name)
+ x = variables.Variable(x_init, name=x_init_name)
+
+ run_options = config_pb2.RunOptions()
+ debug_utils.add_debug_tensor_watch(
+ run_options, x_init_name, 0, debug_urls=[debug_server_url])
+ try:
+ sess.run(x.initializer, options=run_options)
+ except errors.FailedPreconditionError:
+ pass
+
+ if dump_dir:
+ if os.path.isdir(
+ dump_dir) and debug_data.DebugDumpDir(dump_dir).size > 0:
+ shutil.rmtree(dump_dir)
+ print("Poll succeeded.")
+ return True
+ else:
+ print("Poll failed. Sleeping for %f s" % sleep_per_poll_sec)
+ time.sleep(sleep_per_poll_sec)
+ else:
+ if server.debug_tensor_values:
+ print("Poll succeeded.")
+ return True
+ else:
+ print("Poll failed. Sleeping for %f s" % sleep_per_poll_sec)
+ time.sleep(sleep_per_poll_sec)
+
+ return False
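
A short sketch of how a test consumes the helper above in in-memory mode
(mirroring the tests in session_debug_grpc_test.py below; the watch key
"u_init:0:DebugIdentity" follows the <node_name>:<output_slot>:<debug_op>
form):

    (server_port, debug_server_url, _, server_thread,
     server) = start_server_on_separate_thread(dump_to_filesystem=False)

    # ... run a debugged Session with debug_urls=[debug_server_url] ...

    # With dump_to_filesystem=False, received tensors stay in memory,
    # keyed by watch key.
    values = server.debug_tensor_values["u_init:0:DebugIdentity"]

    server.stop_server().wait()
    server_thread.join()
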
diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py
new file mode 100644
index 0000000000..f97b4debd3
--- /dev/null
+++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py
@@ -0,0 +1,622 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for debugger functionalities in tf.Session with grpc:// URLs.
+
+This test file focuses on the grpc:// debugging of local (non-distributed)
+tf.Sessions.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.lib import debug_utils
+from tensorflow.python.debug.lib import grpc_debug_test_server
+from tensorflow.python.debug.lib import session_debug_testlib
+from tensorflow.python.debug.wrappers import framework
+from tensorflow.python.debug.wrappers import grpc_wrapper
+from tensorflow.python.debug.wrappers import hooks
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.training import monitored_session
+
+
+def no_rewrite_session_config():
+ rewriter_config = rewriter_config_pb2.RewriterConfig(
+ disable_model_pruning=True)
+ graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
+ return config_pb2.ConfigProto(graph_options=graph_options)
+
+
+class GrpcDebugServerTest(test_util.TensorFlowTestCase):
+
+ def testRepeatedRunServerRaisesException(self):
+ (_, _, _, server_thread,
+ server) = grpc_debug_test_server.start_server_on_separate_thread(
+ poll_server=True)
+    # The server is started asynchronously, so poll_server=True is passed
+    # above to block until the server has actually started running.
+
+ with self.assertRaisesRegexp(
+ ValueError, "Server has already started running"):
+ server.run_server()
+
+ server.stop_server().wait()
+ server_thread.join()
+
+ def testRepeatedStopServerRaisesException(self):
+ (_, _, _, server_thread,
+ server) = grpc_debug_test_server.start_server_on_separate_thread(
+ poll_server=True)
+ server.stop_server().wait()
+ server_thread.join()
+
+ with self.assertRaisesRegexp(ValueError, "Server has already stopped"):
+ server.stop_server().wait()
+
+ def testRunServerAfterStopRaisesException(self):
+ (_, _, _, server_thread,
+ server) = grpc_debug_test_server.start_server_on_separate_thread(
+ poll_server=True)
+ server.stop_server().wait()
+ server_thread.join()
+
+ with self.assertRaisesRegexp(ValueError, "Server has already stopped"):
+ server.run_server()
+
+
+class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
+
+ @classmethod
+ def setUpClass(cls):
+ session_debug_testlib.SessionDebugTestBase.setUpClass()
+ (cls._server_port, cls._debug_server_url, cls._server_dump_dir,
+ cls._server_thread,
+ cls._server) = grpc_debug_test_server.start_server_on_separate_thread()
+
+ @classmethod
+ def tearDownClass(cls):
+ # Stop the test server and join the thread.
+ cls._server.stop_server().wait()
+ cls._server_thread.join()
+
+ session_debug_testlib.SessionDebugTestBase.tearDownClass()
+
+ def setUp(self):
+ # Override the dump root as the test server's dump directory.
+ self._dump_root = self._server_dump_dir
+
+ def tearDown(self):
+ if os.path.isdir(self._server_dump_dir):
+ shutil.rmtree(self._server_dump_dir)
+ session_debug_testlib.SessionDebugTestBase.tearDown(self)
+
+ def _debug_urls(self, run_number=None):
+ return ["grpc://localhost:%d" % self._server_port]
+
+ def _debug_dump_dir(self, run_number=None):
+ if run_number is None:
+ return self._dump_root
+ else:
+ return os.path.join(self._dump_root, "run_%d" % run_number)
+
+ def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException(self):
+ sess = session.Session(config=no_rewrite_session_config())
+ with self.assertRaisesRegexp(
+ TypeError, "Expected type str or list in grpc_debug_server_addresses"):
+ grpc_wrapper.GrpcDebugWrapperSession(sess, 1337)
+
+ def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException2(self):
+ sess = session.Session(config=no_rewrite_session_config())
+ with self.assertRaisesRegexp(
+ TypeError, "Expected type str in list grpc_debug_server_addresses"):
+ grpc_wrapper.GrpcDebugWrapperSession(sess, ["localhost:1337", 1338])
+
+ def testUseInvalidWatchFnTypeWithGrpcDebugWrapperSessionRaisesException(self):
+ sess = session.Session(config=no_rewrite_session_config())
+ with self.assertRaises(TypeError):
+ grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % self._server_port, watch_fn="foo")
+
+ def testGrpcDebugWrapperSessionWithoutWatchFnWorks(self):
+ u = variables.Variable(2.1, name="u")
+ v = variables.Variable(20.0, name="v")
+ w = math_ops.multiply(u, v, name="w")
+
+ sess = session.Session(config=no_rewrite_session_config())
+ sess.run(u.initializer)
+ sess.run(v.initializer)
+
+ sess = grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % self._server_port)
+ w_result = sess.run(w)
+ self.assertAllClose(42.0, w_result)
+
+ dump = debug_data.DebugDumpDir(self._dump_root)
+ self.assertEqual(5, dump.size)
+ self.assertAllClose([2.1], dump.get_tensors("u", 0, "DebugIdentity"))
+ self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
+ self.assertAllClose([20.0], dump.get_tensors("v", 0, "DebugIdentity"))
+ self.assertAllClose([20.0], dump.get_tensors("v/read", 0, "DebugIdentity"))
+ self.assertAllClose([42.0], dump.get_tensors("w", 0, "DebugIdentity"))
+
+ def testGrpcDebugWrapperSessionWithWatchFnWorks(self):
+ def watch_fn(feeds, fetch_keys):
+ del feeds, fetch_keys
+ return ["DebugIdentity", "DebugNumericSummary"], r".*/read", None
+
+ u = variables.Variable(2.1, name="u")
+ v = variables.Variable(20.0, name="v")
+ w = math_ops.multiply(u, v, name="w")
+
+ sess = session.Session(config=no_rewrite_session_config())
+ sess.run(u.initializer)
+ sess.run(v.initializer)
+
+ sess = grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % self._server_port, watch_fn=watch_fn)
+ w_result = sess.run(w)
+ self.assertAllClose(42.0, w_result)
+
+ dump = debug_data.DebugDumpDir(self._dump_root)
+ self.assertEqual(4, dump.size)
+ self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
+ self.assertEqual(
+ 14, len(dump.get_tensors("u/read", 0, "DebugNumericSummary")[0]))
+ self.assertAllClose([20.0], dump.get_tensors("v/read", 0, "DebugIdentity"))
+ self.assertEqual(
+ 14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0]))
+
+ def testGrpcDebugHookWithStatelessWatchFnWorks(self):
+    # Perform some setup. Specifically, construct a simple TensorFlow graph and
+ # create a watch function for certain ops.
+ def watch_fn(feeds, fetch_keys):
+ del feeds, fetch_keys
+ return framework.WatchOptions(
+ debug_ops=["DebugIdentity", "DebugNumericSummary"],
+ node_name_regex_whitelist=r".*/read",
+ op_type_regex_whitelist=None,
+ tolerate_debug_op_creation_failures=True)
+
+ u = variables.Variable(2.1, name="u")
+ v = variables.Variable(20.0, name="v")
+ w = math_ops.multiply(u, v, name="w")
+
+ sess = session.Session(config=no_rewrite_session_config())
+ sess.run(u.initializer)
+ sess.run(v.initializer)
+
+ # Create a hook. One could use this hook with say a tflearn Estimator.
+ # However, we use a HookedSession in this test to avoid depending on the
+ # internal implementation of Estimators.
+ grpc_debug_hook = hooks.GrpcDebugHook(
+ ["localhost:%d" % self._server_port], watch_fn=watch_fn)
+ sess = monitored_session._HookedSession(sess, [grpc_debug_hook])
+
+ # Run the hooked session. This should stream tensor data to the GRPC
+ # endpoints.
+ w_result = sess.run(w)
+
+ # Verify that the hook monitored the correct tensors.
+ self.assertAllClose(42.0, w_result)
+ dump = debug_data.DebugDumpDir(self._dump_root)
+ self.assertEqual(4, dump.size)
+ self.assertAllClose([2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
+ self.assertEqual(
+ 14, len(dump.get_tensors("u/read", 0, "DebugNumericSummary")[0]))
+ self.assertAllClose([20.0], dump.get_tensors("v/read", 0, "DebugIdentity"))
+ self.assertEqual(
+ 14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0]))
+
+ def testConstructGrpcDebugHookWithGrpcInUrlRaisesValueError(self):
+ """Tests that the hook raises an error if the URL starts with grpc://."""
+ with self.assertRaises(ValueError):
+ hooks.GrpcDebugHook(["grpc://foo:42"])
+
+
+class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ (cls.debug_server_port, cls.debug_server_url, _, cls.debug_server_thread,
+ cls.debug_server
+ ) = grpc_debug_test_server.start_server_on_separate_thread(
+ dump_to_filesystem=False)
+ tf_logging.info("debug server url: %s", cls.debug_server_url)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.debug_server.stop_server().wait()
+ cls.debug_server_thread.join()
+
+ def tearDown(self):
+ ops.reset_default_graph()
+ self.debug_server.clear_data()
+
+ def testSendingLargeGraphDefsWorks(self):
+ with self.test_session(
+ use_gpu=True, config=no_rewrite_session_config()) as sess:
+ u = variables.Variable(42.0, name="original_u")
+ for _ in xrange(50 * 1000):
+ u = array_ops.identity(u)
+ sess.run(variables.global_variables_initializer())
+
+ def watch_fn(fetches, feeds):
+ del fetches, feeds
+ return framework.WatchOptions(
+ debug_ops=["DebugIdentity"],
+ node_name_regex_whitelist=r"original_u")
+ sess = grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+ self.assertAllClose(42.0, sess.run(u))
+
+ self.assertAllClose(
+ [42.0],
+ self.debug_server.debug_tensor_values["original_u:0:DebugIdentity"])
+ self.assertEqual(2 if test.is_gpu_available() else 1,
+ len(self.debug_server.partition_graph_defs))
+ max_graph_def_size = max([
+ len(graph_def.SerializeToString())
+ for graph_def in self.debug_server.partition_graph_defs])
+ self.assertGreater(max_graph_def_size, 4 * 1024 * 1024)
+
+ def testSendingLargeFloatTensorWorks(self):
+ with self.test_session(
+ use_gpu=True, config=no_rewrite_session_config()) as sess:
+ u_init_val_array = list(xrange(1200 * 1024))
+ # Size: 4 * 1200 * 1024 = 4800k > 4M
+
+ u_init = constant_op.constant(
+ u_init_val_array, dtype=dtypes.float32, name="u_init")
+ u = variables.Variable(u_init, name="u")
+
+ def watch_fn(fetches, feeds):
+ del fetches, feeds # Unused by this watch_fn.
+ return framework.WatchOptions(
+ debug_ops=["DebugIdentity"],
+ node_name_regex_whitelist=r"u_init")
+ sess = grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+ sess.run(u.initializer)
+
+ self.assertAllEqual(
+ u_init_val_array,
+ self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
+
+ def testSendingStringTensorWithAlmostTooLargeStringsWorks(self):
+ with self.test_session(
+ use_gpu=True, config=no_rewrite_session_config()) as sess:
+ u_init_val = [
+ b"", b"spam", b"A" * 2500 * 1024, b"B" * 2500 * 1024, b"egg", b""]
+ u_init = constant_op.constant(
+ u_init_val, dtype=dtypes.string, name="u_init")
+ u = variables.Variable(u_init, name="u")
+
+ def watch_fn(fetches, feeds):
+ del fetches, feeds
+ return framework.WatchOptions(
+ debug_ops=["DebugIdentity"],
+ node_name_regex_whitelist=r"u_init")
+ sess = grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+ sess.run(u.initializer)
+
+ self.assertAllEqual(
+ u_init_val,
+ self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
+
+ def testSendingLargeStringTensorWorks(self):
+ with self.test_session(
+ use_gpu=True, config=no_rewrite_session_config()) as sess:
+ strs_total_size_threshold = 5000 * 1024
+ cum_size = 0
+ u_init_val_array = []
+ while cum_size < strs_total_size_threshold:
+ strlen = np.random.randint(200)
+ u_init_val_array.append(b"A" * strlen)
+ cum_size += strlen
+
+ u_init = constant_op.constant(
+ u_init_val_array, dtype=dtypes.string, name="u_init")
+ u = variables.Variable(u_init, name="u")
+
+ def watch_fn(fetches, feeds):
+ del fetches, feeds
+ return framework.WatchOptions(
+ debug_ops=["DebugIdentity"],
+ node_name_regex_whitelist=r"u_init")
+ sess = grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+ sess.run(u.initializer)
+
+ self.assertAllEqual(
+ u_init_val_array,
+ self.debug_server.debug_tensor_values["u_init:0:DebugIdentity"][0])
+
+ def testSendingEmptyFloatTensorWorks(self):
+ with self.test_session(
+ use_gpu=True, config=no_rewrite_session_config()) as sess:
+ u_init = constant_op.constant(
+ [], dtype=dtypes.float32, shape=[0], name="u_init")
+ u = variables.Variable(u_init, name="u")
+
+ def watch_fn(fetches, feeds):
+ del fetches, feeds
+ return framework.WatchOptions(
+ debug_ops=["DebugIdentity"],
+ node_name_regex_whitelist=r"u_init")
+ sess = grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+ sess.run(u.initializer)
+
+ u_init_value = self.debug_server.debug_tensor_values[
+ "u_init:0:DebugIdentity"][0]
+ self.assertEqual(np.float32, u_init_value.dtype)
+ self.assertEqual(0, len(u_init_value))
+
+ def testSendingEmptyStringTensorWorks(self):
+ with self.test_session(
+ use_gpu=True, config=no_rewrite_session_config()) as sess:
+ u_init = constant_op.constant(
+ [], dtype=dtypes.string, shape=[0], name="u_init")
+ u = variables.Variable(u_init, name="u")
+
+ def watch_fn(fetches, feeds):
+ del fetches, feeds
+ return framework.WatchOptions(
+ debug_ops=["DebugIdentity"],
+ node_name_regex_whitelist=r"u_init")
+ sess = grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
+ sess.run(u.initializer)
+
+ u_init_value = self.debug_server.debug_tensor_values[
+ "u_init:0:DebugIdentity"][0]
+ self.assertEqual(np.object, u_init_value.dtype)
+ self.assertEqual(0, len(u_init_value))
+
+
+class SessionDebugConcurrentTest(
+ session_debug_testlib.DebugConcurrentRunCallsTest):
+
+ @classmethod
+ def setUpClass(cls):
+ session_debug_testlib.SessionDebugTestBase.setUpClass()
+ (cls._server_port, cls._debug_server_url, cls._server_dump_dir,
+ cls._server_thread,
+ cls._server) = grpc_debug_test_server.start_server_on_separate_thread()
+
+ @classmethod
+ def tearDownClass(cls):
+ # Stop the test server and join the thread.
+ cls._server.stop_server().wait()
+ cls._server_thread.join()
+ session_debug_testlib.SessionDebugTestBase.tearDownClass()
+
+ def setUp(self):
+ self._num_concurrent_runs = 3
+ self._dump_roots = []
+ for i in range(self._num_concurrent_runs):
+ self._dump_roots.append(
+ os.path.join(self._server_dump_dir, "thread%d" % i))
+
+ def tearDown(self):
+ ops.reset_default_graph()
+ if os.path.isdir(self._server_dump_dir):
+ shutil.rmtree(self._server_dump_dir)
+
+ def _get_concurrent_debug_urls(self):
+ urls = []
+ for i in range(self._num_concurrent_runs):
+ urls.append(self._debug_server_url + "/thread%d" % i)
+ return urls
+
+
+class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
+ """Test server gating of debug ops."""
+
+ @classmethod
+ def setUpClass(cls):
+ (cls._server_port_1, cls._debug_server_url_1, _, cls._server_thread_1,
+ cls._server_1) = grpc_debug_test_server.start_server_on_separate_thread(
+ dump_to_filesystem=False)
+ (cls._server_port_2, cls._debug_server_url_2, _, cls._server_thread_2,
+ cls._server_2) = grpc_debug_test_server.start_server_on_separate_thread(
+ dump_to_filesystem=False)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls._server_1.stop_server().wait()
+ cls._server_thread_1.join()
+ cls._server_2.stop_server().wait()
+ cls._server_thread_2.join()
+
+ def tearDown(self):
+ ops.reset_default_graph()
+ self._server_1.clear_data()
+ self._server_2.clear_data()
+
+ def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenDebugNodes(self):
+ with session.Session(config=no_rewrite_session_config()) as sess:
+ v = variables.Variable(50.0, name="v")
+ delta = constant_op.constant(5.0, name="delta")
+ inc_v = state_ops.assign_add(v, delta, name="inc_v")
+
+ sess.run(v.initializer)
+
+ run_metadata = config_pb2.RunMetadata()
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ debug_utils.watch_graph(
+ run_options,
+ sess.graph,
+ debug_ops=["DebugIdentity(gated_grpc=true)",
+ "DebugNumericSummary(gated_grpc=true)"],
+ debug_urls=[self._debug_server_url_1])
+
+ for i in xrange(4):
+ self._server_1.clear_data()
+
+ # N.B.: These requests will be fulfilled not in this debugged
+ # Session.run() invocation, but in the next one.
+ if i % 2 == 0:
+ self._server_1.request_watch("delta", 0, "DebugIdentity")
+ self._server_1.request_unwatch("delta", 0, "DebugNumericSummary")
+ else:
+ self._server_1.request_unwatch("delta", 0, "DebugIdentity")
+ self._server_1.request_watch("delta", 0, "DebugNumericSummary")
+
+ sess.run(inc_v, options=run_options, run_metadata=run_metadata)
+
+ if i == 0:
+ self.assertEqual(0, len(self._server_1.debug_tensor_values))
+ else:
+ self.assertEqual(1, len(self._server_1.debug_tensor_values))
+ if i % 2 == 1:
+ self.assertAllClose(
+ [5.0],
+ self._server_1.debug_tensor_values["delta:0:DebugIdentity"])
+ else:
+ self.assertAllClose(
+ [[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 5.0, 5.0, 5.0,
+ 0.0, 1.0, 0.0]],
+ self._server_1.debug_tensor_values[
+ "delta:0:DebugNumericSummary"])
+
+ def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenServers(self):
+ with session.Session(config=no_rewrite_session_config()) as sess:
+ v = variables.Variable(50.0, name="v")
+ delta = constant_op.constant(5.0, name="delta")
+ inc_v = state_ops.assign_add(v, delta, name="inc_v")
+
+ sess.run(v.initializer)
+
+ run_metadata = config_pb2.RunMetadata()
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ debug_utils.watch_graph(
+ run_options,
+ sess.graph,
+ debug_ops=["DebugIdentity(gated_grpc=true)"],
+ debug_urls=[self._debug_server_url_1, self._debug_server_url_2])
+
+ for i in xrange(4):
+ self._server_1.clear_data()
+ self._server_2.clear_data()
+
+ # N.B.: These requests will be fulfilled not in this debugged
+ # Session.run() invocation, but in the next one.
+ if i % 2 == 0:
+ self._server_1.request_watch("delta", 0, "DebugIdentity")
+ self._server_2.request_watch("v", 0, "DebugIdentity")
+ else:
+ self._server_1.request_unwatch("delta", 0, "DebugIdentity")
+ self._server_2.request_unwatch("v", 0, "DebugIdentity")
+
+ sess.run(inc_v, options=run_options, run_metadata=run_metadata)
+
+ if i % 2 == 0:
+ self.assertEqual(0, len(self._server_1.debug_tensor_values))
+ self.assertEqual(0, len(self._server_2.debug_tensor_values))
+ else:
+ self.assertEqual(1, len(self._server_1.debug_tensor_values))
+ self.assertEqual(1, len(self._server_2.debug_tensor_values))
+ self.assertAllClose(
+ [5.0],
+ self._server_1.debug_tensor_values["delta:0:DebugIdentity"])
+ self.assertAllClose(
+ [50 + 5.0 * i],
+ self._server_2.debug_tensor_values["v:0:DebugIdentity"])
+
+ def testGetGrpcDebugWatchesReturnsCorrectAnswer(self):
+ with session.Session() as sess:
+ v = variables.Variable(50.0, name="v")
+ delta = constant_op.constant(5.0, name="delta")
+ inc_v = state_ops.assign_add(v, delta, name="inc_v")
+
+ sess.run(v.initializer)
+
+ # Before any debugged runs, the server should be aware of no debug
+ # watches.
+ self.assertEqual([], self._server_1.gated_grpc_debug_watches())
+
+ run_metadata = config_pb2.RunMetadata()
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ debug_utils.add_debug_tensor_watch(
+ run_options, "delta", output_slot=0,
+ debug_ops=["DebugNumericSummary(gated_grpc=true)"],
+ debug_urls=[self._debug_server_url_1])
+ debug_utils.add_debug_tensor_watch(
+ run_options, "v", output_slot=0,
+ debug_ops=["DebugIdentity"],
+ debug_urls=[self._debug_server_url_1])
+ sess.run(inc_v, options=run_options, run_metadata=run_metadata)
+
+ # After the first run, the server should have noted the debug watches
+ # for which gated_grpc == True, but not the ones with gated_grpc == False.
+ self.assertEqual(1, len(self._server_1.gated_grpc_debug_watches()))
+ debug_watch = self._server_1.gated_grpc_debug_watches()[0]
+ self.assertEqual("delta", debug_watch.node_name)
+ self.assertEqual(0, debug_watch.output_slot)
+ self.assertEqual("DebugNumericSummary", debug_watch.debug_op)
+
+
+class DelayedDebugServerTest(test_util.TensorFlowTestCase):
+
+ def testDebuggedSessionRunWorksWithDelayedDebugServerStartup(self):
+ """Test debugged Session.run() tolerates delayed debug server startup."""
+ ops.reset_default_graph()
+
+ # Start a debug server asynchronously, with a certain amount of delay.
+ (debug_server_port, _, _, server_thread,
+ debug_server) = grpc_debug_test_server.start_server_on_separate_thread(
+ server_start_delay_sec=2.0, dump_to_filesystem=False)
+
+ with self.test_session() as sess:
+ a_init = constant_op.constant(42.0, name="a_init")
+ a = variables.Variable(a_init, name="a")
+
+ def watch_fn(fetches, feeds):
+ del fetches, feeds
+ return framework.WatchOptions(debug_ops=["DebugIdentity"])
+
+ sess = grpc_wrapper.GrpcDebugWrapperSession(
+ sess, "localhost:%d" % debug_server_port, watch_fn=watch_fn)
+ sess.run(a.initializer)
+ self.assertAllClose(
+ [42.0], debug_server.debug_tensor_values["a_init:0:DebugIdentity"])
+
+ debug_server.stop_server().wait()
+ server_thread.join()
+
+
+if __name__ == "__main__":
+ googletest.main()
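
The tests above exercise the two watch_fn conventions accepted by
GrpcDebugWrapperSession in this change: a plain tuple and a
framework.WatchOptions object. A side-by-side sketch:

    # Tuple form: (debug_ops, node_name_regex_whitelist,
    # op_type_regex_whitelist).
    def tuple_watch_fn(feeds, fetch_keys):
      del feeds, fetch_keys
      return ["DebugIdentity", "DebugNumericSummary"], r".*/read", None

    # WatchOptions form, which also exposes knobs such as
    # tolerate_debug_op_creation_failures.
    def options_watch_fn(fetches, feeds):
      del fetches, feeds
      return framework.WatchOptions(
          debug_ops=["DebugIdentity", "DebugNumericSummary"],
          node_name_regex_whitelist=r".*/read",
          tolerate_debug_op_creation_failures=True)
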
diff --git a/tensorflow/tools/ci_build/builds/pip.sh b/tensorflow/tools/ci_build/builds/pip.sh
index 7cb93c1774..295601c4b8 100755
--- a/tensorflow/tools/ci_build/builds/pip.sh
+++ b/tensorflow/tools/ci_build/builds/pip.sh
@@ -242,52 +242,101 @@ if [[ $(uname) == "Linux" ]]; then
fi
fi
-# Perform installation
-echo "Installing pip whl file: ${WHL_PATH}"
-# Create virtualenv directory for install test
-VENV_DIR="${PIP_TEST_ROOT}/venv"
+create_activate_virtualenv_and_install_tensorflow() {
+  # Create and activate a virtualenv; then install the tensorflow pip package
+  # in it.
+ #
+ # Usage:
+ # create_activate_virtualenv_and_install_tensorflow [--clean] \
+ # <VIRTUALENV_DIR> <TF_WHEEL_PATH>
+ #
+ # Arguments:
+ # --clean: Create a clean virtualenv, i.e., without --system-site-packages.
+ # VIRTUALENV_DIR: virtualenv directory to be created.
+ # TF_WHEEL_PATH: Path to the tensorflow wheel file to be installed in the
+ # virtualenv.
+
+ VIRTUALENV_FLAGS="--system-site-packages"
+ if [[ "$1" == "--clean" ]]; then
+ VIRTUALENV_FLAGS=""
+ shift
+ fi
+
+ VIRTUALENV_DIR="$1"
+ TF_WHEEL_PATH="$2"
+ if [[ -d "${VIRTUALENV_DIR}" ]]; then
+ if rm -rf "${VIRTUALENV_DIR}"
+ then
+ echo "Removed existing virtualenv directory: ${VIRTUALENV_DIR}"
+ else
+ die "Failed to remove existing virtualenv directory: ${VIRTUALENV_DIR}"
+ fi
+ fi
-if [[ -d "${VENV_DIR}" ]]; then
- if rm -rf "${VENV_DIR}"
+ if mkdir -p "${VIRTUALENV_DIR}"
then
- echo "Removed existing virtualenv directory: ${VENV_DIR}"
+ echo "Created virtualenv directory: ${VIRTUALENV_DIR}"
else
- die "Failed to remove existing virtualenv directory: ${VENV_DIR}"
+ die "FAILED to create virtualenv directory: ${VIRTUALENV_DIR}"
fi
-fi
-
-if mkdir -p ${VENV_DIR}
-then
- echo "Created virtualenv directory: ${VENV_DIR}"
-else
- die "FAILED to create virtualenv directory: ${VENV_DIR}"
-fi
-# Verify that virtualenv exists
-if [[ -z $(which virtualenv) ]]; then
- die "FAILED: virtualenv not available on path"
-fi
+ # Verify that virtualenv exists
+ if [[ -z $(which virtualenv) ]]; then
+ die "FAILED: virtualenv not available on path"
+ fi
-virtualenv --system-site-packages -p "${PYTHON_BIN_PATH}" "${VENV_DIR}" || \
+ virtualenv ${VIRTUALENV_FLAGS} \
+ -p "${PYTHON_BIN_PATH}" "${VIRTUALENV_DIR}" || \
die "FAILED: Unable to create virtualenv"
-source "${VENV_DIR}/bin/activate" || \
- die "FAILED: Unable to activate virtualenv"
-
+ source "${VIRTUALENV_DIR}/bin/activate" || \
+ die "FAILED: Unable to activate virtualenv in ${VIRTUALENV_DIR}"
-# Install the pip file in virtual env (plus missing dependencies)
+  # Install the pip file in the virtualenv.
-# Upgrade pip so it supports tags such as cp27mu, manylinux1 etc.
-echo "Upgrade pip in virtualenv"
-pip install --upgrade pip==8.1.2
+ # Upgrade pip so it supports tags such as cp27mu, manylinux1 etc.
+ echo "Upgrade pip in virtualenv"
+ pip install --upgrade pip==8.1.2
-# Force tensorflow reinstallation. Otherwise it may not get installed from
-# last build if it had the same version number as previous build.
-PIP_FLAGS="--upgrade --force-reinstall"
-pip install -v ${PIP_FLAGS} ${WHL_PATH} || \
+ # Force tensorflow reinstallation. Otherwise it may not get installed from
+ # last build if it had the same version number as previous build.
+ PIP_FLAGS="--upgrade --force-reinstall"
+  pip install -v ${PIP_FLAGS} ${TF_WHEEL_PATH} || \
die "pip install (forcing to reinstall tensorflow) FAILED"
-echo "Successfully installed pip package ${WHL_PATH}"
+ echo "Successfully installed pip package ${TF_WHEEL_PATH}"
+}
+
+
+# 1. Smoke test of tensorflow install in clean virtualenv
+echo
+echo "Installing and smoke-testing pip wheel in clean virtualenv: ${WHL_PATH}"
+echo
+
+CLEAN_VENV_DIR="${PIP_TEST_ROOT}/venv_clean"
+create_activate_virtualenv_and_install_tensorflow --clean \
+ "${CLEAN_VENV_DIR}" "${WHL_PATH}"
+# cd to a temporary directory to avoid picking up Python files in the source
+# tree.
+TMP_DIR=$(mktemp -d)
+pushd "${TMP_DIR}"
+[[ $(python -c "import tensorflow as tf; print(tf.Session().run(tf.constant(42)))") == 42 ]] \
+ && echo "Smoke test of tensorflow install in clean virtualenv PASSED." \
+ || die "Smoke test of tensroflow install in clean virtualenv FAILED."
+deactivate || \
+ die "FAILED: Unable to deactivate virtualenv from ${CLEAN_VENV_DIR}"
+popd
+rm -rf "${TMP_DIR}" "${CLEAN_VENV_DIR}"
+
+# 2. Perform installation of tensorflow in "non-clean" virtualenv and tests
+# against the install.
+echo
+echo "Installing and testing pip wheel in virtualenv: ${WHL_PATH}"
+echo
+
+# Create virtualenv directory for install test
+VENV_DIR="${PIP_TEST_ROOT}/venv"
+create_activate_virtualenv_and_install_tensorflow \
+ "${CLEAN_VENV_DIR}" "${WHL_PATH}"
# Install extra pip packages required by the test-on-install
for PACKAGE in ${INSTALL_EXTRA_PIP_PACKAGES}; do
@@ -343,5 +392,4 @@ if [[ "${DO_INTEGRATION_TESTS}" == "1" ]]; then
die "Integration tests on install FAILED"
fi
-deactivate || \
- die "FAILED: Unable to deactivate virtualenv"
+deactivate || die "FAILED: Unable to deactivate virtualenv from ${VENV_DIR}"
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index f82640faf4..644cab95f8 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -171,6 +171,7 @@ sh_binary(
"//tensorflow/python/debug:debug_pip",
"//tensorflow/python/saved_model:saved_model",
"//tensorflow/python/tools:tools_pip",
+ "//tensorflow/tools/dist_test/server:grpc_tensorflow_server",
],
}) + if_mkl(["//third_party/mkl:intel_binary_blob"]),
)