aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-07 09:57:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-07 10:05:48 -0800
commitdb2a81a82c1753154648309f70f6ed3fa8f60837 (patch)
tree078e59f14062382856ac03f22d4fb7646c5e69f5
parent3b99c7947a01cd85a6b551daff974f122ff5b1c8 (diff)
Add a :debug BUILD target which, when linked into a binary, enables
DirectSession support for TensorFlow Debugger (tfdbg). Binaries that do not want debugging support can avoid this dependency and its transitive deps. This replaces the previous approach that was based on a preprocessor flag (-DNOTFDBG). Change: 141321165
-rw-r--r--tensorflow/contrib/cmake/tf_core_direct_session.cmake4
-rw-r--r--tensorflow/contrib/cmake/tf_core_framework.cmake1
-rw-r--r--tensorflow/contrib/makefile/Makefile2
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_cc_files.txt1
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_h_files.txt1
-rw-r--r--tensorflow/contrib/makefile/tf_pb_text_files.txt1
-rw-r--r--tensorflow/contrib/makefile/tf_proto_files.txt1
-rw-r--r--tensorflow/core/BUILD48
-rw-r--r--tensorflow/core/common_runtime/debugger_state_interface.cc37
-rw-r--r--tensorflow/core/common_runtime/debugger_state_interface.h72
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc11
-rw-r--r--tensorflow/core/common_runtime/direct_session.h12
-rw-r--r--tensorflow/core/debug/BUILD17
-rw-r--r--tensorflow/core/debug/debug.cc40
-rw-r--r--tensorflow/core/debug/debug_gateway_test.cc18
-rw-r--r--tensorflow/core/debug/debug_graph_utils.cc6
-rw-r--r--tensorflow/core/debug/debug_graph_utils.h12
-rw-r--r--tensorflow/core/protobuf/config.proto33
-rw-r--r--tensorflow/core/protobuf/debug.proto37
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/debug/cli/analyzer_cli_test.py2
-rw-r--r--tensorflow/python/debug/debug_data.py4
-rw-r--r--tensorflow/python/debug/debug_utils.py2
-rw-r--r--tensorflow/python/debug/debug_utils_test.py67
-rw-r--r--tensorflow/python/debug/wrappers/hooks.py2
-rw-r--r--tensorflow/python/training/monitored_session.py4
-rw-r--r--tensorflow/python/training/monitored_session_test.py21
27 files changed, 341 insertions, 116 deletions
diff --git a/tensorflow/contrib/cmake/tf_core_direct_session.cmake b/tensorflow/contrib/cmake/tf_core_direct_session.cmake
index 0fd1287197..712f04ddc1 100644
--- a/tensorflow/contrib/cmake/tf_core_direct_session.cmake
+++ b/tensorflow/contrib/cmake/tf_core_direct_session.cmake
@@ -15,6 +15,4 @@ list(REMOVE_ITEM tf_core_direct_session_srcs ${tf_core_direct_session_test_srcs}
add_library(tf_core_direct_session OBJECT ${tf_core_direct_session_srcs})
-add_dependencies(tf_core_direct_session tf_core_cpu)
-
-add_definitions(-DNOTFDBG) \ No newline at end of file
+add_dependencies(tf_core_direct_session tf_core_cpu) \ No newline at end of file
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index 86d91e6876..e903471f36 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -103,6 +103,7 @@ set(tf_proto_text_srcs
"tensorflow/core/framework/versions.proto"
"tensorflow/core/lib/core/error_codes.proto"
"tensorflow/core/protobuf/config.proto"
+ "tensorflow/core/protobuf/debug.proto"
"tensorflow/core/protobuf/tensor_bundle.proto"
"tensorflow/core/protobuf/saver.proto"
"tensorflow/core/util/memmapped_file_system.proto"
diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile
index 4c0d32f115..2bf246bdf9 100644
--- a/tensorflow/contrib/makefile/Makefile
+++ b/tensorflow/contrib/makefile/Makefile
@@ -138,7 +138,7 @@ $(shell mkdir -p $(DEPDIR) >/dev/null)
# Settings for the target compiler.
CXX := $(CC_PREFIX) gcc
OPTFLAGS := -O2
-CXXFLAGS := --std=c++11 -DIS_SLIM_BUILD -fno-exceptions -DNDEBUG -DNOTFDBG $(OPTFLAGS)
+CXXFLAGS := --std=c++11 -DIS_SLIM_BUILD -fno-exceptions -DNDEBUG $(OPTFLAGS)
LDFLAGS := \
-L/usr/local/lib
DEPFLAGS = -MT $@ -MMD -MP -MF $(DEPDIR)/$*.Td
diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
index 8572b0bfea..395b8bde60 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
@@ -8,6 +8,7 @@ tensorflow/core/protobuf/queue_runner.pb.cc
tensorflow/core/protobuf/named_tensor.pb.cc
tensorflow/core/protobuf/meta_graph.pb.cc
tensorflow/core/protobuf/config.pb.cc
+tensorflow/core/protobuf/debug.pb.cc
tensorflow/core/lib/core/error_codes.pb.cc
tensorflow/core/framework/versions.pb.cc
tensorflow/core/framework/variable.pb.cc
diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
index a5f9079f54..4bd371f4dc 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
@@ -8,6 +8,7 @@ tensorflow/core/protobuf/queue_runner.pb.h
tensorflow/core/protobuf/named_tensor.pb.h
tensorflow/core/protobuf/meta_graph.pb.h
tensorflow/core/protobuf/config.pb.h
+tensorflow/core/protobuf/debug.pb.h
tensorflow/core/protobuf/tensor_bundle.pb.h
tensorflow/core/lib/core/error_codes.pb.h
tensorflow/core/framework/versions.pb.h
diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt
index a99c8b2be7..f1657793b2 100644
--- a/tensorflow/contrib/makefile/tf_pb_text_files.txt
+++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt
@@ -2,6 +2,7 @@ tensorflow/core/util/saved_tensor_slice.pb_text.cc
tensorflow/core/util/memmapped_file_system.pb_text.cc
tensorflow/core/protobuf/saver.pb_text.cc
tensorflow/core/protobuf/config.pb_text.cc
+tensorflow/core/protobuf/debug.pb_text.cc
tensorflow/core/protobuf/tensor_bundle.pb_text.cc
tensorflow/core/lib/core/error_codes.pb_text.cc
tensorflow/core/framework/versions.pb_text.cc
diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt
index 5b53a58876..f60e7f2360 100644
--- a/tensorflow/contrib/makefile/tf_proto_files.txt
+++ b/tensorflow/contrib/makefile/tf_proto_files.txt
@@ -8,6 +8,7 @@ tensorflow/core/protobuf/queue_runner.proto
tensorflow/core/protobuf/named_tensor.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/protobuf/config.proto
+tensorflow/core/protobuf/debug.proto
tensorflow/core/protobuf/tensor_bundle.proto
tensorflow/core/lib/core/error_codes.proto
tensorflow/core/framework/versions.proto
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 9d00a248df..81a517c506 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -135,6 +135,7 @@ CORE_PROTO_SRCS = [
"framework/versions.proto",
"lib/core/error_codes.proto",
"protobuf/config.proto",
+ "protobuf/debug.proto",
"protobuf/tensor_bundle.proto",
"protobuf/saver.proto",
"util/memmapped_file_system.proto",
@@ -768,7 +769,6 @@ cc_library(
srcs = if_android(["//tensorflow/core:android_srcs"]),
copts = tf_copts() + [
"-Os",
- "-DNOTFDBG",
],
linkopts = ["-lz"],
tags = [
@@ -795,7 +795,7 @@ cc_library(
"//tensorflow/core/kernels:android_core_ops",
"//tensorflow/core/kernels:android_extended_ops",
]),
- copts = tf_copts() + ["-Os"] + ["-std=c++11"] + ["-DNOTFDBG"],
+ copts = tf_copts() + ["-Os"] + ["-std=c++11"],
visibility = ["//visibility:public"],
deps = [
":protos_cc",
@@ -828,7 +828,7 @@ cc_library(
cc_library(
name = "android_tensorflow_lib",
srcs = if_android([":android_op_registrations_and_gradients"]),
- copts = tf_copts() + ["-DNOTFDBG"],
+ copts = tf_copts(),
tags = [
"manual",
"notap",
@@ -853,7 +853,6 @@ cc_library(
copts = tf_copts() + [
"-Os",
"-DSUPPORT_SELECTIVE_REGISTRATION",
- "-DNOTFDBG",
],
tags = [
"manual",
@@ -1366,7 +1365,6 @@ tf_cuda_library(
":lib_internal",
":proto_text",
":protos_all_cc",
- "//tensorflow/core/debug:debug_graph_utils",
],
alwayslink = 1,
)
@@ -1937,6 +1935,46 @@ tf_cc_test(
],
)
+# This is identical to :common_runtime_direct_session_test with the addition of
+# a dependency on alwayslink target //third_party/tensorflow/core/debug, which
+# enables support for TensorFlow Debugger (tfdbg).
+tf_cc_test(
+ name = "common_runtime_direct_session_with_debug_test",
+ size = "small",
+ srcs = ["common_runtime/direct_session_test.cc"],
+ linkstatic = tf_kernel_tests_linkstatic(),
+ deps = [
+ ":core",
+ ":core_cpu",
+ ":core_cpu_internal",
+ ":direct_session_internal",
+ ":framework",
+ ":framework_internal",
+ ":lib",
+ ":lib_internal",
+ ":ops",
+ ":protos_all_cc",
+ ":test",
+ ":test_main",
+ ":testlib",
+ "//third_party/eigen3",
+ "//tensorflow/cc:cc_ops",
+ # Link with support for TensorFlow Debugger (tfdbg).
+ "//tensorflow/core/debug",
+ "//tensorflow/core/kernels:control_flow_ops",
+ "//tensorflow/core/kernels:cwise_op",
+ "//tensorflow/core/kernels:dense_update_ops",
+ "//tensorflow/core/kernels:fifo_queue_op",
+ "//tensorflow/core/kernels:function_ops",
+ "//tensorflow/core/kernels:identity_op",
+ "//tensorflow/core/kernels:matmul_op",
+ "//tensorflow/core/kernels:ops_util",
+ "//tensorflow/core/kernels:queue_ops",
+ "//tensorflow/core/kernels:session_ops",
+ "//tensorflow/core/kernels:variable_ops",
+ ],
+)
+
tf_cc_test(
name = "common_runtime_direct_session_with_tracking_alloc_test",
size = "small",
diff --git a/tensorflow/core/common_runtime/debugger_state_interface.cc b/tensorflow/core/common_runtime/debugger_state_interface.cc
new file mode 100644
index 0000000000..2e2fbcd7f4
--- /dev/null
+++ b/tensorflow/core/common_runtime/debugger_state_interface.cc
@@ -0,0 +1,37 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/debugger_state_interface.h"
+
+namespace tensorflow {
+
+DebuggerStateFactory* DebuggerStateRegistry::factory_ = nullptr;
+
+// static
+void DebuggerStateRegistry::RegisterFactory(
+ const DebuggerStateFactory& factory) {
+ delete factory_;
+ factory_ = new DebuggerStateFactory(factory);
+}
+
+// static
+std::unique_ptr<DebuggerStateInterface> DebuggerStateRegistry::CreateState(
+ const DebugOptions& debug_options) {
+ return (factory_ == nullptr || *factory_ == nullptr)
+ ? nullptr
+ : (*factory_)(debug_options);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/common_runtime/debugger_state_interface.h b/tensorflow/core/common_runtime/debugger_state_interface.h
new file mode 100644
index 0000000000..5db9ba20f3
--- /dev/null
+++ b/tensorflow/core/common_runtime/debugger_state_interface.h
@@ -0,0 +1,72 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
+#define TENSORFLOW_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
+
+#include <memory>
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+
+class DebugOptions; // Defined in core/protobuf/debug.h.
+class Device;
+class Graph;
+
+// An abstract interface for storing and retrieving debugging information.
+class DebuggerStateInterface {
+ public:
+ virtual ~DebuggerStateInterface() {}
+
+ // Returns a summary string for RepeatedPtrFields of DebugTensorWatches.
+ virtual const string SummarizeDebugTensorWatches() = 0;
+
+ // Insert special-purpose debug nodes to graph. See the documentation of
+ // DebugNodeInserter::InsertNodes() for details.
+ virtual Status InsertNodes(Graph* graph, Device* device) = 0;
+};
+
+typedef std::function<std::unique_ptr<DebuggerStateInterface>(
+ const DebugOptions& options)>
+ DebuggerStateFactory;
+
+// Contains only static methods for registering DebuggerStateFactory.
+// We don't expect to create any instances of this class.
+// Call DebuggerStateRegistry::RegisterFactory() at initialization time to
+// define a global factory that creates instances of DebuggerState, then call
+// DebuggerStateRegistry::CreateState() to create a single instance.
+class DebuggerStateRegistry {
+ public:
+ // Registers a function that creates a concrete DebuggerStateInterface
+ // implementation based on DebugOptions.
+ static void RegisterFactory(const DebuggerStateFactory& factory);
+
+ // If RegisterFactory() has been called, creates and returns a concrete
+ // DebuggerStateInterface implementation using the registered factory,
+ // owned by the caller. Otherwise returns nullptr.
+ static std::unique_ptr<DebuggerStateInterface> CreateState(
+ const DebugOptions& debug_options);
+
+ private:
+ static DebuggerStateFactory* factory_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DebuggerStateRegistry);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index e1f2c55230..6052d57239 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/constant_folding.h"
+#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -395,12 +396,10 @@ Status DirectSession::Run(const RunOptions& run_options,
ExecutorsAndKeys* executors_and_keys;
RunStateArgs run_state_args;
-#ifndef NOTFDBG
// EXPERIMENTAL: Options that allow the client to insert nodes into partition
// graphs for debugging.
- run_state_args.debugger_state.reset(
- new DebuggerState(run_options.debug_tensor_watch_opts()));
-#endif
+ run_state_args.debugger_state =
+ DebuggerStateRegistry::CreateState(run_options.debug_options());
TF_RETURN_IF_ERROR(
GetOrCreateExecutors(pool, input_tensor_names, output_names, target_nodes,
@@ -880,12 +879,10 @@ Status DirectSession::GetOrCreateExecutors(
std::sort(tn_sorted.begin(), tn_sorted.end());
string debug_tensor_watches_summary;
-#ifndef NOTFDBG
if (run_state_args->debugger_state) {
debug_tensor_watches_summary =
run_state_args->debugger_state->SummarizeDebugTensorWatches();
}
-#endif
const string key = strings::StrCat(
str_util::Join(inputs_sorted, ","), "->",
str_util::Join(outputs_sorted, ","), "/", str_util::Join(tn_sorted, ","),
@@ -985,12 +982,10 @@ Status DirectSession::GetOrCreateExecutors(
optimizer.Optimize(lib, options_.env, device, &partition_graph);
// EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) to the graph
-#ifndef NOTFDBG
if (run_state_args->debugger_state) {
TF_RETURN_IF_ERROR(run_state_args->debugger_state->InsertNodes(
partition_graph, params.device));
}
-#endif
iter->second.reset(partition_graph);
TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 127c08d0a4..c9a693c556 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -24,15 +24,13 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/costmodel_manager.h"
+#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
-#ifndef NOTFDBG
-#include "tensorflow/core/debug/debug_graph_utils.h"
-#endif
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/session_state.h"
@@ -48,9 +46,7 @@ limitations under the License.
namespace tensorflow {
class CostModel;
-#ifndef NOTFDBG
class DebugGateway;
-#endif
class Device;
class DirectSessionFactory;
@@ -164,9 +160,7 @@ class DirectSession : public Session {
bool is_partial_run = false;
string handle;
std::unique_ptr<Graph> graph;
-#ifndef NOTFDBG
- std::unique_ptr<DebuggerState> debugger_state;
-#endif
+ std::unique_ptr<DebuggerStateInterface> debugger_state;
};
// Initializes the base execution state given the 'graph',
@@ -303,10 +297,8 @@ class DirectSession : public Session {
TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
-#ifndef NOTFDBG
// EXPERIMENTAL: debugger (tfdbg) related
friend class DebugGateway;
-#endif
};
} // end namespace tensorflow
diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD
index 04e146be85..2363b69390 100644
--- a/tensorflow/core/debug/BUILD
+++ b/tensorflow/core/debug/BUILD
@@ -39,6 +39,21 @@ tf_proto_library_cc(
cc_libs = ["//tensorflow/core:protos_all_cc"],
)
+# Depending on this target causes a concrete DebuggerState implementation
+# to be registered at initialization time. For details, please see
+# core/common_runtime/debugger_state_interface.h.
+cc_library(
+ name = "debug",
+ srcs = ["debug.cc"],
+ copts = tf_copts(),
+ linkstatic = 1,
+ deps = [
+ ":debug_graph_utils",
+ "//tensorflow/core:core_cpu_internal",
+ ],
+ alwayslink = 1,
+)
+
tf_cuda_library(
name = "debug_gateway_internal",
srcs = ["debug_gateway.cc"],
@@ -46,6 +61,7 @@ tf_cuda_library(
copts = tf_copts(),
linkstatic = 1,
deps = [
+ ":debug",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:direct_session_internal",
"//tensorflow/core:framework",
@@ -136,6 +152,7 @@ tf_cc_test_gpu(
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
+ ":debug",
":debug_gateway_internal",
":debug_graph_utils",
"//tensorflow/cc:cc_ops",
diff --git a/tensorflow/core/debug/debug.cc b/tensorflow/core/debug/debug.cc
new file mode 100644
index 0000000000..c293b285c3
--- /dev/null
+++ b/tensorflow/core/debug/debug.cc
@@ -0,0 +1,40 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/core/common_runtime/debugger_state_interface.h"
+#include "tensorflow/core/debug/debug_graph_utils.h"
+
+namespace tensorflow {
+namespace {
+
+// Registers a concrete implementation of DebuggerState for use by
+// DirectSession.
+class DebuggerStateRegistration {
+ public:
+ static std::unique_ptr<DebuggerStateInterface> CreateDebuggerState(
+ const DebugOptions& options) {
+ return std::unique_ptr<DebuggerStateInterface>(new DebuggerState(options));
+ }
+
+ DebuggerStateRegistration() {
+ DebuggerStateRegistry::RegisterFactory(CreateDebuggerState);
+ }
+};
+static DebuggerStateRegistration register_debugger_state_implementation;
+
+} // end namespace
+} // end namespace tensorflow
diff --git a/tensorflow/core/debug/debug_gateway_test.cc b/tensorflow/core/debug/debug_gateway_test.cc
index 1fab9a56a3..1f6e766663 100644
--- a/tensorflow/core/debug/debug_gateway_test.cc
+++ b/tensorflow/core/debug/debug_gateway_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <unordered_map>
+#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/lib/core/notification.h"
@@ -228,7 +229,8 @@ TEST_F(SessionDebugMinusAXTest, RunSimpleNetworkWithTwoDebugNodesInserted) {
const string debug_identity = "DebugIdentity";
const string debug_nan_count = "DebugNanCount";
- DebugTensorWatch* tensor_watch_opts = run_opts.add_debug_tensor_watch_opts();
+ DebugTensorWatch* tensor_watch_opts =
+ run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
tensor_watch_opts->set_node_name(y_);
tensor_watch_opts->set_output_slot(0);
tensor_watch_opts->add_debug_ops(debug_identity);
@@ -409,7 +411,7 @@ TEST_F(SessionDebugMinusAXTest,
run_opts.set_output_partition_graphs(true);
DebugTensorWatch* tensor_watch_opts =
- run_opts.add_debug_tensor_watch_opts();
+ run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
tensor_watch_opts->set_output_slot(0);
tensor_watch_opts->add_debug_ops(debug_identity);
@@ -561,7 +563,8 @@ TEST_F(SessionDebugOutputSlotWithoutOngoingEdgeTest,
RunOptions run_opts;
run_opts.set_output_partition_graphs(true);
- DebugTensorWatch* tensor_watch_opts = run_opts.add_debug_tensor_watch_opts();
+ DebugTensorWatch* tensor_watch_opts =
+ run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
tensor_watch_opts->set_node_name(c_);
tensor_watch_opts->set_output_slot(0);
tensor_watch_opts->add_debug_ops("DebugIdentity");
@@ -659,7 +662,8 @@ TEST_F(SessionDebugVariableTest, WatchUninitializedVariableWithDebugOps) {
// Set up DebugTensorWatch for an uninitialized tensor (in node var).
RunOptions run_opts;
const string debug_identity = "DebugIdentity";
- DebugTensorWatch* tensor_watch_opts = run_opts.add_debug_tensor_watch_opts();
+ DebugTensorWatch* tensor_watch_opts =
+ run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
tensor_watch_opts->set_node_name(var_node_name_);
tensor_watch_opts->set_output_slot(0);
tensor_watch_opts->add_debug_ops(debug_identity);
@@ -746,7 +750,8 @@ TEST_F(SessionDebugVariableTest, VariableAssignWithDebugOps) {
run_opts.set_output_partition_graphs(true);
const string debug_identity = "DebugIdentity";
const string debug_nan_count = "DebugNanCount";
- DebugTensorWatch* tensor_watch_opts = run_opts.add_debug_tensor_watch_opts();
+ DebugTensorWatch* tensor_watch_opts =
+ run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
tensor_watch_opts->set_node_name(var_node_name_);
tensor_watch_opts->set_output_slot(0);
tensor_watch_opts->add_debug_ops(debug_identity);
@@ -904,7 +909,8 @@ TEST_F(SessionDebugGPUSwitchTest, RunSwitchWithHostMemoryDebugOp) {
const string watched_tensor = strings::StrCat(pred_node_name_, "/_1");
const string debug_identity = "DebugIdentity";
- DebugTensorWatch* tensor_watch_opts = run_opts.add_debug_tensor_watch_opts();
+ DebugTensorWatch* tensor_watch_opts =
+ run_opts.mutable_debug_options()->add_debug_tensor_watch_opts();
tensor_watch_opts->set_node_name(watched_tensor);
tensor_watch_opts->set_output_slot(0);
tensor_watch_opts->add_debug_ops(debug_identity);
diff --git a/tensorflow/core/debug/debug_graph_utils.cc b/tensorflow/core/debug/debug_graph_utils.cc
index ca5f9b74f2..854d6c9050 100644
--- a/tensorflow/core/debug/debug_graph_utils.cc
+++ b/tensorflow/core/debug/debug_graph_utils.cc
@@ -22,12 +22,12 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/protobuf/debug.pb.h"
namespace tensorflow {
-DebuggerState::DebuggerState(
- const protobuf::RepeatedPtrField<DebugTensorWatch>& watches)
- : watches(watches), debug_urls_() {
+DebuggerState::DebuggerState(const DebugOptions& debug_options)
+ : watches(debug_options.debug_tensor_watch_opts()), debug_urls_() {
for (const DebugTensorWatch& watch : watches) {
for (const string& url : watch.debug_urls()) {
debug_urls_.insert(url);
diff --git a/tensorflow/core/debug/debug_graph_utils.h b/tensorflow/core/debug/debug_graph_utils.h
index 1b37ddf9d8..79e56f35a1 100644
--- a/tensorflow/core/debug/debug_graph_utils.h
+++ b/tensorflow/core/debug/debug_graph_utils.h
@@ -20,26 +20,26 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/protobuf.h"
-#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/protobuf/debug.pb.h"
namespace tensorflow {
-class DebuggerState {
+class DebuggerState : public DebuggerStateInterface {
public:
- DebuggerState(
- const protobuf::RepeatedPtrField<DebugTensorWatch>& debug_tensor_watches);
+ DebuggerState(const DebugOptions& debug_options);
virtual ~DebuggerState();
// Returns a summary string for RepeatedPtrFields of DebugTensorWatches.
- const string SummarizeDebugTensorWatches();
+ const string SummarizeDebugTensorWatches() override;
// Insert special-purpose debug nodes to graph. See the documentation of
// DebugNodeInserter::InsertNodes() for details.
- Status InsertNodes(Graph* graph, Device* device);
+ Status InsertNodes(Graph* graph, Device* device) override;
const protobuf::RepeatedPtrField<DebugTensorWatch>& watches;
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index a3b8b400a3..3e4589ab00 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -9,6 +9,7 @@ option java_package = "org.tensorflow.framework";
import "tensorflow/core/framework/cost_graph.proto";
import "tensorflow/core/framework/graph.proto";
import "tensorflow/core/framework/step_stats.proto";
+import "tensorflow/core/protobuf/debug.proto";
message GPUOptions {
// A value between 0 and 1 that indicates what fraction of the
@@ -222,30 +223,6 @@ message ConfigProto {
int64 operation_timeout_in_ms = 11;
};
-// EXPERIMENTAL. Option for watching a node.
-message DebugTensorWatch {
- // Name of the node to watch.
- string node_name = 1;
-
- // Output slot to watch.
- // The semantics of output_slot == -1 is that the node is only watched for
- // completion, but not for any output tensors. See NodeCompletionCallback
- // in debug_gateway.h.
- // TODO(cais): Implement this semantics.
- int32 output_slot = 2;
-
- // Name(s) of the debugging op(s).
- // One or more than one probes on a tensor.
- // e.g., {"DebugIdentity", "DebugNanCount"}
- repeated string debug_ops = 3;
-
- // URL(s) for debug targets(s).
- // E.g., "file:///foo/tfdbg_dump", "grpc://localhost:11011"
- // Each debug op listed in debug_ops will publish its output tensor (debug
- // signal) to all URLs in debug_urls.
- repeated string debug_urls = 4;
-}
-
// EXPERIMENTAL. Options for a single Run() call.
message RunOptions {
// TODO(pbar) Turn this into a TraceOptions proto which allows
@@ -264,12 +241,14 @@ message RunOptions {
// The thread pool to use, if session_inter_op_thread_pool is configured.
int32 inter_op_thread_pool = 3;
- // Debugging options
- repeated DebugTensorWatch debug_tensor_watch_opts = 4;
-
// Whether the partition graph(s) executed by the executor(s) should be
// outputted via RunMetadata.
bool output_partition_graphs = 5;
+
+ // EXPERIMENTAL. Options used to intialize DebuggerState, if enabled.
+ DebugOptions debug_options = 6;
+
+ reserved 4;
}
// EXPERIMENTAL. Metadata output (i.e., non-Tensor) for a single Run() call.
diff --git a/tensorflow/core/protobuf/debug.proto b/tensorflow/core/protobuf/debug.proto
new file mode 100644
index 0000000000..5b32f9fc0b
--- /dev/null
+++ b/tensorflow/core/protobuf/debug.proto
@@ -0,0 +1,37 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "DebugProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+// EXPERIMENTAL. Option for watching a node.
+message DebugTensorWatch {
+ // Name of the node to watch.
+ string node_name = 1;
+
+ // Output slot to watch.
+ // The semantics of output_slot == -1 is that the node is only watched for
+ // completion, but not for any output tensors. See NodeCompletionCallback
+ // in debug_gateway.h.
+ // TODO(cais): Implement this semantics.
+ int32 output_slot = 2;
+
+ // Name(s) of the debugging op(s).
+ // One or more than one probes on a tensor.
+ // e.g., {"DebugIdentity", "DebugNanCount"}
+ repeated string debug_ops = 3;
+
+ // URL(s) for debug targets(s).
+ // E.g., "file:///foo/tfdbg_dump", "grpc://localhost:11011"
+ // Each debug op listed in debug_ops will publish its output tensor (debug
+ // signal) to all URLs in debug_urls.
+ repeated string debug_urls = 4;
+}
+
+// EXPERIMENTAL. Options for initializing DebuggerState.
+message DebugOptions {
+ // Debugging options
+ repeated DebugTensorWatch debug_tensor_watch_opts = 4;
+}
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 2841b3b560..0b88404560 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2018,6 +2018,7 @@ tf_py_wrap_cc(
"//tensorflow/c:checkpoint_reader",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:lib",
+ "//tensorflow/core/debug",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/tools/tfprof/internal:print_model_analysis",
"//util/python:python_headers",
diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py
index ac1ece6af7..7d409aecca 100644
--- a/tensorflow/python/debug/cli/analyzer_cli_test.py
+++ b/tensorflow/python/debug/cli/analyzer_cli_test.py
@@ -992,7 +992,7 @@ class AnalyzerCLIWhileLoopTest(test_util.TensorFlowTestCase):
run_options = config_pb2.RunOptions(output_partition_graphs=True)
debug_url = "file://%s" % cls._dump_root
- watch_opts = run_options.debug_tensor_watch_opts
+ watch_opts = run_options.debug_options.debug_tensor_watch_opts
# Add debug tensor watch for "while/Identity".
watch = watch_opts.add()
diff --git a/tensorflow/python/debug/debug_data.py b/tensorflow/python/debug/debug_data.py
index c5bcf0a904..027f8cf798 100644
--- a/tensorflow/python/debug/debug_data.py
+++ b/tensorflow/python/debug/debug_data.py
@@ -114,7 +114,7 @@ def _is_copy_node(node_name):
"""Determine whether a node name is that of a debug Copy node.
Such nodes are inserted by TensorFlow core upon request in
- RunOptions.debug_tensor_watch_opts.
+ RunOptions.debug_options.debug_tensor_watch_opts.
Args:
node_name: Name of the node.
@@ -130,7 +130,7 @@ def _is_debug_node(node_name):
"""Determine whether a node name is that of a debug node.
Such nodes are inserted by TensorFlow core upon request in
- RunOptions.debug_tensor_watch_opts.
+ RunOptions.debug_options.debug_tensor_watch_opts.
Args:
node_name: Name of the node.
diff --git a/tensorflow/python/debug/debug_utils.py b/tensorflow/python/debug/debug_utils.py
index d0ed071faf..ea672f9c51 100644
--- a/tensorflow/python/debug/debug_utils.py
+++ b/tensorflow/python/debug/debug_utils.py
@@ -41,7 +41,7 @@ def add_debug_tensor_watch(run_options,
string with only one element.
"""
- watch_opts = run_options.debug_tensor_watch_opts
+ watch_opts = run_options.debug_options.debug_tensor_watch_opts
watch = watch_opts.add()
watch.node_name = node_name
diff --git a/tensorflow/python/debug/debug_utils_test.py b/tensorflow/python/debug/debug_utils_test.py
index 9c5dccdac6..081c27a6c4 100644
--- a/tensorflow/python/debug/debug_utils_test.py
+++ b/tensorflow/python/debug/debug_utils_test.py
@@ -98,16 +98,16 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
debug_utils.add_debug_tensor_watch(
self._run_options, "foo/node_b", 0, debug_urls="file:///tmp/tfdbg_2")
- self.assertEqual(2, len(self._run_options.debug_tensor_watch_opts))
+ debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
+ self.assertEqual(2, len(debug_watch_opts))
- watch_0 = self._run_options.debug_tensor_watch_opts[0]
- watch_1 = self._run_options.debug_tensor_watch_opts[1]
+ watch_0 = debug_watch_opts[0]
+ watch_1 = debug_watch_opts[1]
self.assertEqual("foo/node_a", watch_0.node_name)
self.assertEqual(1, watch_0.output_slot)
self.assertEqual("foo/node_b", watch_1.node_name)
self.assertEqual(0, watch_1.output_slot)
-
# Verify default debug op name.
self.assertEqual(["DebugIdentity"], watch_0.debug_ops)
self.assertEqual(["DebugIdentity"], watch_1.debug_ops)
@@ -124,9 +124,10 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
debug_ops="DebugNanCount",
debug_urls="file:///tmp/tfdbg_1")
- self.assertEqual(1, len(self._run_options.debug_tensor_watch_opts))
+ debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
+ self.assertEqual(1, len(debug_watch_opts))
- watch_0 = self._run_options.debug_tensor_watch_opts[0]
+ watch_0 = debug_watch_opts[0]
self.assertEqual("foo/node_a", watch_0.node_name)
self.assertEqual(0, watch_0.output_slot)
@@ -145,9 +146,10 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
debug_ops=["DebugNanCount", "DebugIdentity"],
debug_urls="file:///tmp/tfdbg_1")
- self.assertEqual(1, len(self._run_options.debug_tensor_watch_opts))
+ debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
+ self.assertEqual(1, len(debug_watch_opts))
- watch_0 = self._run_options.debug_tensor_watch_opts[0]
+ watch_0 = debug_watch_opts[0]
self.assertEqual("foo/node_a", watch_0.node_name)
self.assertEqual(0, watch_0.output_slot)
@@ -166,9 +168,10 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
debug_ops="DebugNanCount",
debug_urls=["file:///tmp/tfdbg_1", "file:///tmp/tfdbg_2"])
- self.assertEqual(1, len(self._run_options.debug_tensor_watch_opts))
+ debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
+ self.assertEqual(1, len(debug_watch_opts))
- watch_0 = self._run_options.debug_tensor_watch_opts[0]
+ watch_0 = debug_watch_opts[0]
self.assertEqual("foo/node_a", watch_0.node_name)
self.assertEqual(0, watch_0.output_slot)
@@ -187,13 +190,13 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
debug_ops=["DebugIdentity", "DebugNanCount"],
debug_urls="file:///tmp/tfdbg_1")
- self.assertEqual(self._expected_num_nodes,
- len(self._run_options.debug_tensor_watch_opts))
+ debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
+ self.assertEqual(self._expected_num_nodes, len(debug_watch_opts))
# Verify that each of the nodes in the graph with output tensors in the
# graph have debug tensor watch.
- node_names = self._verify_watches(self._run_options.debug_tensor_watch_opts,
- 0, ["DebugIdentity", "DebugNanCount"],
+ node_names = self._verify_watches(debug_watch_opts, 0,
+ ["DebugIdentity", "DebugNanCount"],
["file:///tmp/tfdbg_1"])
# Verify the node names.
@@ -218,9 +221,9 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
debug_urls="file:///tmp/tfdbg_1",
node_name_regex_whitelist="(a1$|a1_init$|a1/.*|p1$)")
- node_names = self._verify_watches(self._run_options.debug_tensor_watch_opts,
- 0, ["DebugIdentity"],
- ["file:///tmp/tfdbg_1"])
+ node_names = self._verify_watches(
+ self._run_options.debug_options.debug_tensor_watch_opts, 0,
+ ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
self.assertEqual(
sorted(["a1_init", "a1", "a1/Assign", "a1/read", "p1"]),
sorted(node_names))
@@ -232,9 +235,9 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
debug_urls="file:///tmp/tfdbg_1",
op_type_regex_whitelist="(Variable|MatMul)")
- node_names = self._verify_watches(self._run_options.debug_tensor_watch_opts,
- 0, ["DebugIdentity"],
- ["file:///tmp/tfdbg_1"])
+ node_names = self._verify_watches(
+ self._run_options.debug_options.debug_tensor_watch_opts, 0,
+ ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
self.assertEqual(sorted(["a1", "b", "p1"]), sorted(node_names))
def testWatchGraph_nodeNameAndOpTypeWhitelists(self):
@@ -245,9 +248,9 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
node_name_regex_whitelist="([a-z]+1$)",
op_type_regex_whitelist="(MatMul)")
- node_names = self._verify_watches(self._run_options.debug_tensor_watch_opts,
- 0, ["DebugIdentity"],
- ["file:///tmp/tfdbg_1"])
+ node_names = self._verify_watches(
+ self._run_options.debug_options.debug_tensor_watch_opts, 0,
+ ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
self.assertEqual(["p1"], node_names)
def testWatchGraph_nodeNameBlacklist(self):
@@ -257,9 +260,9 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
debug_urls="file:///tmp/tfdbg_1",
node_name_regex_blacklist="(a1$|a1_init$|a1/.*|p1$)")
- node_names = self._verify_watches(self._run_options.debug_tensor_watch_opts,
- 0, ["DebugIdentity"],
- ["file:///tmp/tfdbg_1"])
+ node_names = self._verify_watches(
+ self._run_options.debug_options.debug_tensor_watch_opts, 0,
+ ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
self.assertEqual(
sorted(["b_init", "b", "b/Assign", "b/read", "c", "s"]),
sorted(node_names))
@@ -271,9 +274,9 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
debug_urls="file:///tmp/tfdbg_1",
op_type_regex_blacklist="(Variable|Identity|Assign|Const)")
- node_names = self._verify_watches(self._run_options.debug_tensor_watch_opts,
- 0, ["DebugIdentity"],
- ["file:///tmp/tfdbg_1"])
+ node_names = self._verify_watches(
+ self._run_options.debug_options.debug_tensor_watch_opts, 0,
+ ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
self.assertEqual(sorted(["p1", "s"]), sorted(node_names))
def testWatchGraph_nodeNameAndOpTypeBlacklists(self):
@@ -284,9 +287,9 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
node_name_regex_blacklist="p1$",
op_type_regex_blacklist="(Variable|Identity|Assign|Const)")
- node_names = self._verify_watches(self._run_options.debug_tensor_watch_opts,
- 0, ["DebugIdentity"],
- ["file:///tmp/tfdbg_1"])
+ node_names = self._verify_watches(
+ self._run_options.debug_options.debug_tensor_watch_opts, 0,
+ ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
self.assertEqual(["s"], node_names)
diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py
index 14123d95f8..684d40e0f1 100644
--- a/tensorflow/python/debug/wrappers/hooks.py
+++ b/tensorflow/python/debug/wrappers/hooks.py
@@ -79,7 +79,7 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook,
self.on_run_end(on_run_end_request)
def _decorate_options_for_debug(self, options, graph):
- """Modify RunOptions.debug_tensor_watch_opts for debugging.
+ """Modify RunOptions.debug_options.debug_tensor_watch_opts for debugging.
Args:
options: (config_pb2.RunOptions) The RunOptions instance to be modified.
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 32077d198f..168bad8554 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -795,5 +795,5 @@ class _HookedSession(_WrappedSession):
options.output_partition_graphs,
incoming_options.output_partition_graphs)
- options.debug_tensor_watch_opts.extend(
- incoming_options.debug_tensor_watch_opts)
+ options.debug_options.debug_tensor_watch_opts.extend(
+ incoming_options.debug_options.debug_tensor_watch_opts)
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 59edb61f3a..8fe7084084 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -29,6 +29,7 @@ import tensorflow as tf
from tensorflow.contrib import testing
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import debug_pb2
from tensorflow.python.training import monitored_session
@@ -693,7 +694,8 @@ class RunOptionsMetadataHook(tf.train.SessionRunHook):
trace_level=self._trace_level,
timeout_in_ms=self._timeout_in_ms,
output_partition_graphs=self._output_partition_graphs)
- options.debug_tensor_watch_opts.extend([self._debug_tensor_watch])
+ options.debug_options.debug_tensor_watch_opts.extend(
+ [self._debug_tensor_watch])
return tf.train.SessionRunArgs(None, None, options=options)
def after_run(self, run_context, run_values):
@@ -1019,13 +1021,13 @@ class MonitoredSessionTest(tf.test.TestCase):
my_const = tf.constant(42, name='my_const')
_ = tf.constant(24, name='my_const_2')
- watch_a = config_pb2.DebugTensorWatch(
+ watch_a = debug_pb2.DebugTensorWatch(
node_name='my_const',
output_slot=0,
debug_ops=['DebugIdentity'],
debug_urls=[])
hook_a = RunOptionsMetadataHook(2, 30000, False, watch_a)
- watch_b = config_pb2.DebugTensorWatch(
+ watch_b = debug_pb2.DebugTensorWatch(
node_name='my_const_2',
output_slot=0,
debug_ops=['DebugIdentity'],
@@ -1044,7 +1046,8 @@ class MonitoredSessionTest(tf.test.TestCase):
trace_level=3,
timeout_in_ms=60000,
output_partition_graphs=True,
- debug_tensor_watch_opts=[watch_a, watch_b])
+ debug_options=debug_pb2.DebugOptions(
+ debug_tensor_watch_opts=[watch_a, watch_b]))
],
hook_b.run_options_list)
self.assertEqual(1, len(hook_b.run_metadata_list))
@@ -1059,21 +1062,22 @@ class MonitoredSessionTest(tf.test.TestCase):
my_const = tf.constant(42, name='my_const')
_ = tf.constant(24, name='my_const_2')
- hook_watch = config_pb2.DebugTensorWatch(
+ hook_watch = debug_pb2.DebugTensorWatch(
node_name='my_const_2',
output_slot=0,
debug_ops=['DebugIdentity'],
debug_urls=[])
hook = RunOptionsMetadataHook(2, 60000, False, hook_watch)
with tf.train.MonitoredSession(hooks=[hook]) as session:
- caller_watch = config_pb2.DebugTensorWatch(
+ caller_watch = debug_pb2.DebugTensorWatch(
node_name='my_const',
output_slot=0,
debug_ops=['DebugIdentity'],
debug_urls=[])
caller_options = config_pb2.RunOptions(
trace_level=3, timeout_in_ms=30000, output_partition_graphs=True)
- caller_options.debug_tensor_watch_opts.extend([caller_watch])
+ caller_options.debug_options.debug_tensor_watch_opts.extend(
+ [caller_watch])
self.assertEqual(42, session.run(my_const, options=caller_options))
# trace_level=3 from the caller should override 2 from the hook.
@@ -1088,7 +1092,8 @@ class MonitoredSessionTest(tf.test.TestCase):
trace_level=3,
timeout_in_ms=60000,
output_partition_graphs=True,
- debug_tensor_watch_opts=[caller_watch, hook_watch])
+ debug_options=debug_pb2.DebugOptions(
+ debug_tensor_watch_opts=[caller_watch, hook_watch]))
],
hook.run_options_list)
self.assertEqual(1, len(hook.run_metadata_list))