author    Shanqing Cai <cais@google.com>  2017-04-29 08:26:06 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-04-29 09:48:44 -0700
commit    a25509eda3e42dddc88155965eaffe4b3c455af5 (patch)
tree      79bc1880238cd291c64123c2c241122da1c000b5
parent    ea25bc496e30ecc1f90622906390669913d92e74 (diff)
Add TFDBG support to GrpcSession
* Along the way, unify the way the debugger works in DirectSession (non-distributed Sessions) and MasterSession (for distributed Sessions).
* The SummarizeDebugTensorWatches() method is invoked in DirectSession::GetOrCreateExecutors() and in HashBuildGraphOptions() (master_session.cc) to generate keys for partition graphs and executors.
* The DebuggerStateInterface::PublishDebugMetadata() method is used to send metadata about the debugged Session::Run() call to debug URLs. This happens in DirectSession::Run() and MasterSession::DoRunWithLocalExecution(), respectively.
* The DebugGraphDecoratorInterface::DecorateGraph() and DebugGraphDecoratorInterface::PublishGraph() methods are used to insert debug ops into the debugged graph and send the modified graph to debug URLs. This happens in DirectSession::GetOrCreateExecutors() and GraphMgr::InitItem(), respectively.

Change: 154631802
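
For orientation, a minimal client-side sketch of what this enables, mirroring the pattern exercised by grpc_session_debug_test.cc further down; the node name, fetch name, and file:// dump URL here are placeholders, not values from this change. The client attaches DebugTensorWatch entries to RunOptions::debug_options before calling Session::Run() on a GrpcSession.

    #include <string>
    #include <vector>

    #include "tensorflow/core/protobuf/config.pb.h"  // RunOptions, RunMetadata
    #include "tensorflow/core/protobuf/debug.pb.h"   // DebugOptions, DebugTensorWatch
    #include "tensorflow/core/public/session.h"

    // Sketch: run one fetch through a (possibly remote) session with a tfdbg
    // tensor watch attached. The watched node, fetch name, and debug URL are
    // caller-supplied placeholders.
    tensorflow::Status RunWithDebugWatch(tensorflow::Session* session,
                                         const std::string& watched_node,
                                         const std::string& fetch_name,
                                         std::vector<tensorflow::Tensor>* outputs) {
      tensorflow::RunOptions run_options;
      tensorflow::DebugTensorWatch* watch =
          run_options.mutable_debug_options()->add_debug_tensor_watch_opts();
      // Probe output slot 0 of the watched node with DebugIdentity and stream
      // the resulting dumps to a file:// debug URL.
      watch->set_node_name(watched_node);
      watch->set_output_slot(0);
      watch->add_debug_ops("DebugIdentity");
      watch->add_debug_urls("file:///tmp/tfdbg_dumps");

      tensorflow::RunMetadata run_metadata;
      return session->Run(run_options, /*inputs=*/{}, {fetch_name},
                          /*target_node_names=*/{}, outputs, &run_metadata);
    }

With a non-empty debug_options, MasterSession forwards the options to workers via RegisterGraphRequest.debug_options, and GraphMgr::InitItem() decorates each partition graph before execution (see the worker.proto and graph_mgr.cc hunks below).
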
-rw-r--r--  tensorflow/core/BUILD                                         |   1
-rw-r--r--  tensorflow/core/common_runtime/build_graph_options.h          |   3
-rw-r--r--  tensorflow/core/common_runtime/debugger_state_interface.cc    |  45
-rw-r--r--  tensorflow/core/common_runtime/debugger_state_interface.h     |  48
-rw-r--r--  tensorflow/core/common_runtime/direct_session.cc              |  79
-rw-r--r--  tensorflow/core/common_runtime/direct_session.h               |  14
-rw-r--r--  tensorflow/core/debug/BUILD                                   |  41
-rw-r--r--  tensorflow/core/debug/debug.cc                                |  10
-rw-r--r--  tensorflow/core/debug/debug_graph_utils.cc                    |  64
-rw-r--r--  tensorflow/core/debug/debug_graph_utils.h                     |  30
-rw-r--r--  tensorflow/core/debug/debug_grpc_testlib.cc                   |  47
-rw-r--r--  tensorflow/core/debug/debug_grpc_testlib.h                    |   2
-rw-r--r--  tensorflow/core/debug/debugger_state_impl.cc                  |  66
-rw-r--r--  tensorflow/core/debug/debugger_state_impl.h                   |  61
-rw-r--r--  tensorflow/core/debug/grpc_session_debug_test.cc              | 288
-rw-r--r--  tensorflow/core/distributed_runtime/BUILD                     |   2
-rw-r--r--  tensorflow/core/distributed_runtime/graph_mgr.cc              |  31
-rw-r--r--  tensorflow/core/distributed_runtime/graph_mgr.h               |  10
-rw-r--r--  tensorflow/core/distributed_runtime/master_session.cc         |  62
-rw-r--r--  tensorflow/core/distributed_runtime/master_session.h          |   6
-rw-r--r--  tensorflow/core/distributed_runtime/worker.cc                 |   2
-rw-r--r--  tensorflow/core/protobuf/worker.proto                         |   4
22 files changed, 761 insertions, 155 deletions
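
As a point of reference for how the new SummarizeDebugTensorWatches() free function (declared in debugger_state_interface.h below) feeds into executor and partition-graph keying, here is a rough, self-contained sketch of the summary string it produces; the node name and URL are made-up values:

    #include <string>

    #include "tensorflow/core/common_runtime/debugger_state_interface.h"
    #include "tensorflow/core/protobuf/debug.pb.h"

    // Builds one watch and returns its summary. For node "a", slot 0, debug op
    // "DebugIdentity", and URL "file:///tmp/tfdbg_dumps", the result is
    //   "a:0|DebugIdentity,@file:///tmp/tfdbg_dumps,;"
    // (a watch with tolerate_debug_op_creation_failures() set gains a "(TOL)"
    // prefix).
    std::string ExampleWatchSummary() {
      tensorflow::DebugOptions debug_options;
      tensorflow::DebugTensorWatch* watch =
          debug_options.add_debug_tensor_watch_opts();
      watch->set_node_name("a");
      watch->set_output_slot(0);
      watch->add_debug_ops("DebugIdentity");
      watch->add_debug_urls("file:///tmp/tfdbg_dumps");
      return tensorflow::SummarizeDebugTensorWatches(
          debug_options.debug_tensor_watch_opts());
    }

DirectSession::GetOrCreateExecutors() folds this summary into its executor lookup key, and HashBuildGraphOptions() hashes it into the MasterSession graph key, so runs with different watch configurations resolve to distinct executors and partition graphs.
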
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 79fd7ec01e..1617addba0 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1570,6 +1570,7 @@ tf_cuda_library(
":lib_internal",
":proto_text",
":protos_all_cc",
+ "//tensorflow/core/debug:debug_graph_utils",
"//tensorflow/core/kernels:function_ops",
],
alwayslink = 1,
diff --git a/tensorflow/core/common_runtime/build_graph_options.h b/tensorflow/core/common_runtime/build_graph_options.h
index 49566c8fa8..5f0e8f170b 100644
--- a/tensorflow/core/common_runtime/build_graph_options.h
+++ b/tensorflow/core/common_runtime/build_graph_options.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/debug.pb.h"
namespace tensorflow {
@@ -35,6 +36,8 @@ struct BuildGraphOptions {
// TODO(mrry): Remove this when the distributed runtime supports Arg/Retval.
bool use_function_convention = false;
+ DebugOptions debug_options;
+
string DebugString() const;
};
diff --git a/tensorflow/core/common_runtime/debugger_state_interface.cc b/tensorflow/core/common_runtime/debugger_state_interface.cc
index 2e2fbcd7f4..73157ca05d 100644
--- a/tensorflow/core/common_runtime/debugger_state_interface.cc
+++ b/tensorflow/core/common_runtime/debugger_state_interface.cc
@@ -17,9 +17,40 @@ limitations under the License.
namespace tensorflow {
+// static
DebuggerStateFactory* DebuggerStateRegistry::factory_ = nullptr;
// static
+DebugGraphDecoratorFactory* DebugGraphDecoratorRegistry::factory_ = nullptr;
+
+const string SummarizeDebugTensorWatches(
+ const protobuf::RepeatedPtrField<DebugTensorWatch>& watches) {
+ std::ostringstream oss;
+
+ for (const DebugTensorWatch& watch : watches) {
+ string tensor_name =
+ strings::StrCat(watch.node_name(), ":", watch.output_slot());
+ if (watch.tolerate_debug_op_creation_failures()) {
+ oss << "(TOL)"; // Shorthand for "tolerate".
+ }
+ oss << tensor_name << "|";
+
+ for (const string& debug_op : watch.debug_ops()) {
+ oss << debug_op << ",";
+ }
+
+ oss << "@";
+ for (const string& debug_url : watch.debug_urls()) {
+ oss << debug_url << ",";
+ }
+
+ oss << ";";
+ }
+
+ return oss.str();
+}
+
+// static
void DebuggerStateRegistry::RegisterFactory(
const DebuggerStateFactory& factory) {
delete factory_;
@@ -34,4 +65,18 @@ std::unique_ptr<DebuggerStateInterface> DebuggerStateRegistry::CreateState(
: (*factory_)(debug_options);
}
+// static
+void DebugGraphDecoratorRegistry::RegisterFactory(
+ const DebugGraphDecoratorFactory& factory) {
+ delete factory_;
+ factory_ = new DebugGraphDecoratorFactory(factory);
+}
+
+// static
+std::unique_ptr<DebugGraphDecoratorInterface>
+DebugGraphDecoratorRegistry::CreateDecorator(const DebugOptions& options) {
+ return (factory_ == nullptr || *factory_ == nullptr) ? nullptr
+ : (*factory_)(options);
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/common_runtime/debugger_state_interface.h b/tensorflow/core/common_runtime/debugger_state_interface.h
index fb72f9fa3e..d182ce7092 100644
--- a/tensorflow/core/common_runtime/debugger_state_interface.h
+++ b/tensorflow/core/common_runtime/debugger_state_interface.h
@@ -18,28 +18,24 @@ limitations under the License.
#include <memory>
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/protobuf/debug.pb.h"
namespace tensorflow {
-class DebugOptions; // Defined in core/protobuf/debug.h.
-class Device;
-class Graph;
+// Returns a summary string for the list of debug tensor watches.
+const string SummarizeDebugTensorWatches(
+ const protobuf::RepeatedPtrField<DebugTensorWatch>& watches);
// An abstract interface for storing and retrieving debugging information.
class DebuggerStateInterface {
public:
virtual ~DebuggerStateInterface() {}
- // Returns a summary string for RepeatedPtrFields of DebugTensorWatches.
- virtual const string SummarizeDebugTensorWatches() = 0;
-
- // Insert special-purpose debug nodes to graph and dump the graph for
- // record. See the documentation of DebugNodeInserter::InsertNodes() for
- // details.
- virtual Status DecorateGraphForDebug(Graph* graph, Device* device) = 0;
-
// Publish metadata about the debugged Session::Run() call.
//
// Args:
@@ -59,6 +55,19 @@ class DebuggerStateInterface {
const std::vector<string>& target_nodes) = 0;
};
+class DebugGraphDecoratorInterface {
+ public:
+ virtual ~DebugGraphDecoratorInterface() {}
+
+ // Insert special-purpose debug nodes to graph and dump the graph for
+ // record. See the documentation of DebugNodeInserter::InsertNodes() for
+ // details.
+ virtual Status DecorateGraph(Graph* graph, Device* device) = 0;
+
+ // Publish Graph to debug URLs.
+ virtual Status PublishGraph(const Graph& graph) = 0;
+};
+
typedef std::function<std::unique_ptr<DebuggerStateInterface>(
const DebugOptions& options)>
DebuggerStateFactory;
@@ -86,6 +95,23 @@ class DebuggerStateRegistry {
TF_DISALLOW_COPY_AND_ASSIGN(DebuggerStateRegistry);
};
+typedef std::function<std::unique_ptr<DebugGraphDecoratorInterface>(
+ const DebugOptions& options)>
+ DebugGraphDecoratorFactory;
+
+class DebugGraphDecoratorRegistry {
+ public:
+ static void RegisterFactory(const DebugGraphDecoratorFactory& factory);
+
+ static std::unique_ptr<DebugGraphDecoratorInterface> CreateDecorator(
+ const DebugOptions& options);
+
+ private:
+ static DebugGraphDecoratorFactory* factory_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DebugGraphDecoratorRegistry);
+};
+
} // end namespace tensorflow
#endif // TENSORFLOW_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index f208e4b78e..7c017f9584 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -370,6 +370,43 @@ Status DirectSession::Run(const NamedTensorList& inputs,
&run_metadata);
}
+Status DirectSession::CreateDebuggerState(
+ const DebugOptions& debug_options, int64 session_run_count,
+ int64 executor_step_count, const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_names,
+ std::unique_ptr<DebuggerStateInterface>* debugger_state) {
+ std::unique_ptr<DebuggerStateInterface> state =
+ DebuggerStateRegistry::CreateState(debug_options);
+ if (!state) {
+ return errors::Internal(
+ "Debugger options are set, but creation of debugger state failed. "
+ "It appears that debugger is not linked in this TensorFlow build.");
+ }
+
+ TF_RETURN_IF_ERROR(state->PublishDebugMetadata(
+ debug_options.global_step(), session_run_count, executor_step_count,
+ input_names, output_names, target_names));
+
+ *debugger_state = std::move(state);
+ return Status::OK();
+}
+
+Status DirectSession::DecorateAndPublishGraphForDebug(
+ const DebugOptions& debug_options, Graph* graph, Device* device) {
+ std::unique_ptr<DebugGraphDecoratorInterface> decorator =
+ DebugGraphDecoratorRegistry::CreateDecorator(debug_options);
+ if (!decorator) {
+ return errors::Internal(
+ "Debugger options are set, but creation of debug graph publisher ",
+ "failed.");
+ }
+
+ TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
+ TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph));
+ return Status::OK();
+}
+
Status DirectSession::Run(const RunOptions& run_options,
const NamedTensorList& inputs,
const std::vector<string>& output_names,
@@ -402,27 +439,21 @@ Status DirectSession::Run(const RunOptions& run_options,
// Check if we already have an executor for these arguments.
ExecutorsAndKeys* executors_and_keys;
- RunStateArgs run_state_args;
+ RunStateArgs run_state_args(run_options.debug_options());
Executor::Args args;
args.step_id = step_id_counter_.fetch_add(1);
- // EXPERIMENTAL: Options that allow the client to insert nodes into partition
- // graphs for debugging.
- if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
- run_state_args.debugger_state =
- DebuggerStateRegistry::CreateState(run_options.debug_options());
- }
-
TF_RETURN_IF_ERROR(
GetOrCreateExecutors(pool, input_tensor_names, output_names, target_nodes,
&executors_and_keys, &run_state_args));
const int64 executor_step_count = executors_and_keys->step_count.fetch_add(1);
- if (run_state_args.debugger_state) {
- TF_RETURN_IF_ERROR(run_state_args.debugger_state->PublishDebugMetadata(
- run_options.debug_options().global_step(), args.step_id,
- executor_step_count, input_tensor_names, output_names, target_nodes));
+ std::unique_ptr<DebuggerStateInterface> debugger_state;
+ if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
+ TF_RETURN_IF_ERROR(CreateDebuggerState(
+ run_options.debug_options(), args.step_id, executor_step_count,
+ input_tensor_names, output_names, target_nodes, &debugger_state));
}
// Configure a call frame for the step, which we use to feed and
@@ -629,7 +660,9 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
// Check if we already have an executor for these arguments.
ExecutorsAndKeys* executors_and_keys;
- RunStateArgs run_state_args;
+ // TODO(cais): TFDBG support for partial runs.
+ DebugOptions debug_options;
+ RunStateArgs run_state_args(debug_options);
run_state_args.is_partial_run = true;
TF_RETURN_IF_ERROR(GetOrCreateExecutors(pool, input_names, output_names,
target_nodes, &executors_and_keys,
@@ -960,14 +993,15 @@ Status DirectSession::GetOrCreateExecutors(
thread::ThreadPool* pool, gtl::ArraySlice<string> inputs,
gtl::ArraySlice<string> outputs, gtl::ArraySlice<string> target_nodes,
ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args) {
- string debug_tensor_watches_summary;
int64 handle_name_counter_value = -1;
if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
handle_name_counter_value = handle_name_counter_.fetch_add(1);
}
- if (run_state_args->debugger_state) {
- debug_tensor_watches_summary =
- run_state_args->debugger_state->SummarizeDebugTensorWatches();
+
+ string debug_tensor_watches_summary;
+ if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
+ debug_tensor_watches_summary = SummarizeDebugTensorWatches(
+ run_state_args->debug_options.debug_tensor_watch_opts());
}
// Fast lookup path, no sorting.
@@ -1032,6 +1066,9 @@ Status DirectSession::GetOrCreateExecutors(
options.fetch_endpoints = outputs_sorted;
options.target_nodes = tn_sorted;
options.use_function_convention = !run_state_args->is_partial_run;
+ if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
+ options.debug_options = run_state_args->debug_options;
+ }
std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
@@ -1107,10 +1144,10 @@ Status DirectSession::GetOrCreateExecutors(
optimizer.Optimize(lib, options_.env, device, &iter->second);
- // EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) to the graph
- if (run_state_args->debugger_state) {
- TF_RETURN_IF_ERROR(run_state_args->debugger_state->DecorateGraphForDebug(
- partition_graph.get(), params.device));
+ // EXPERIMENTAL: tfdbg inserts debug nodes in the graph.
+ if (!options.debug_options.debug_tensor_watch_opts().empty()) {
+ TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
+ options.debug_options, partition_graph.get(), params.device));
}
TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 061a7fa787..cc298b3e57 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -169,10 +169,12 @@ class DirectSession : public Session {
};
struct RunStateArgs {
+ RunStateArgs(const DebugOptions& options) : debug_options(options) {}
+
bool is_partial_run = false;
string handle;
std::unique_ptr<Graph> graph;
- std::unique_ptr<DebuggerStateInterface> debugger_state;
+ const DebugOptions& debug_options;
};
// Initializes the base execution state given the 'graph',
@@ -239,6 +241,16 @@ class DirectSession : public Session {
return ::tensorflow::Status::OK();
}
+ ::tensorflow::Status CreateDebuggerState(
+ const DebugOptions& debug_options, int64 session_run_count,
+ int64 executor_step_count, const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_names,
+ std::unique_ptr<DebuggerStateInterface>* debugger_state);
+
+ ::tensorflow::Status DecorateAndPublishGraphForDebug(
+ const DebugOptions& debug_options, Graph* graph, Device* device);
+
const SessionOptions options_;
// Device structures.
diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD
index 372ddb168a..c46e9e4c63 100644
--- a/tensorflow/core/debug/BUILD
+++ b/tensorflow/core/debug/BUILD
@@ -58,7 +58,7 @@ cc_library(
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
- ":debug_graph_utils",
+ ":debugger_state_impl",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:debug_ops_op_lib",
],
@@ -86,13 +86,25 @@ tf_cuda_library(
)
tf_cuda_library(
+ name = "debugger_state_impl",
+ srcs = ["debugger_state_impl.cc"],
+ hdrs = ["debugger_state_impl.h"],
+ copts = tf_copts(),
+ linkstatic = 1,
+ deps = [
+ ":debug_graph_utils",
+ ":debug_io_utils",
+ ],
+ alwayslink = 1,
+)
+
+tf_cuda_library(
name = "debug_graph_utils",
srcs = ["debug_graph_utils.cc"],
hdrs = ["debug_graph_utils.h"],
copts = tf_copts(),
linkstatic = 1,
deps = [
- ":debug_io_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -135,6 +147,7 @@ tf_cuda_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
"@grpc//:grpc++_unsecure",
],
alwayslink = 1,
@@ -209,6 +222,30 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "grpc_session_debug_test",
+ size = "medium",
+ srcs = ["grpc_session_debug_test.cc"],
+ linkstatic = tf_kernel_tests_linkstatic(),
+ deps = [
+ ":debug_grpc_testlib",
+ ":debug_io_utils",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_session",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_testlib",
+ "//tensorflow/core/kernels:constant_op",
+ "//tensorflow/core/kernels:matmul_op",
+ ],
+)
+
# TODO(cais): Add the following back in when tfdbg is supported on Android.
# filegroup(
# name = "android_srcs",
diff --git a/tensorflow/core/debug/debug.cc b/tensorflow/core/debug/debug.cc
index c293b285c3..1aedfc2710 100644
--- a/tensorflow/core/debug/debug.cc
+++ b/tensorflow/core/debug/debug.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
-#include "tensorflow/core/debug/debug_graph_utils.h"
+#include "tensorflow/core/debug/debugger_state_impl.h"
namespace tensorflow {
namespace {
@@ -30,10 +30,18 @@ class DebuggerStateRegistration {
return std::unique_ptr<DebuggerStateInterface>(new DebuggerState(options));
}
+ static std::unique_ptr<DebugGraphDecoratorInterface>
+ CreateDebugGraphDecorator(const DebugOptions& options) {
+ return std::unique_ptr<DebugGraphDecoratorInterface>(
+ new DebugGraphDecorator(options));
+ }
+
DebuggerStateRegistration() {
DebuggerStateRegistry::RegisterFactory(CreateDebuggerState);
+ DebugGraphDecoratorRegistry::RegisterFactory(CreateDebugGraphDecorator);
}
};
+
static DebuggerStateRegistration register_debugger_state_implementation;
} // end namespace
diff --git a/tensorflow/core/debug/debug_graph_utils.cc b/tensorflow/core/debug/debug_graph_utils.cc
index 9061fd39f5..a222dc75d7 100644
--- a/tensorflow/core/debug/debug_graph_utils.cc
+++ b/tensorflow/core/debug/debug_graph_utils.cc
@@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/common_runtime/memory_types.h"
-#include "tensorflow/core/debug/debug_io_utils.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -45,69 +44,6 @@ Status ParseBoolString(const string& bool_str, bool* bool_val) {
} // namespace
-DebuggerState::DebuggerState(const DebugOptions& debug_options)
- : watches(debug_options.debug_tensor_watch_opts()), debug_urls_() {
- for (const DebugTensorWatch& watch : watches) {
- for (const string& url : watch.debug_urls()) {
- debug_urls_.insert(url);
- }
- }
-}
-
-DebuggerState::~DebuggerState() {
- for (const string& debug_url : debug_urls_) {
- DebugIO::CloseDebugURL(debug_url).IgnoreError();
- }
-}
-
-const string DebuggerState::SummarizeDebugTensorWatches() {
- std::ostringstream oss;
-
- for (const DebugTensorWatch& watch : watches) {
- string tensor_name =
- strings::StrCat(watch.node_name(), ":", watch.output_slot());
- if (watch.tolerate_debug_op_creation_failures()) {
- oss << "(TOL)"; // Shorthand for "tolerate".
- }
- oss << tensor_name << "|";
-
- for (const string& debug_op : watch.debug_ops()) {
- oss << debug_op << ",";
- }
-
- oss << "@";
- for (const string& debug_url : watch.debug_urls()) {
- oss << debug_url << ",";
- }
-
- oss << ";";
- }
-
- return oss.str();
-}
-
-Status DebuggerState::DecorateGraphForDebug(Graph* graph, Device* device) {
- Status status;
-
- DebugNodeInserter::DeparallelizeWhileLoops(graph, device);
- status.Update(DebugNodeInserter::InsertNodes(watches, graph, device));
- if (status.ok()) {
- status.Update(DebugIO::PublishGraph(*graph, debug_urls_));
- }
-
- return status;
-}
-
-Status DebuggerState::PublishDebugMetadata(
- const int64 global_step, const int64 session_run_count,
- const int64 executor_step_count, const std::vector<string>& input_names,
- const std::vector<string>& output_names,
- const std::vector<string>& target_nodes) {
- return DebugIO::PublishDebugMetadata(global_step, session_run_count,
- executor_step_count, input_names,
- output_names, target_nodes, debug_urls_);
-}
-
// static
Status DebugNodeInserter::InsertNodes(
const protobuf::RepeatedPtrField<DebugTensorWatch>& watches, Graph* graph,
diff --git a/tensorflow/core/debug/debug_graph_utils.h b/tensorflow/core/debug/debug_graph_utils.h
index ac97856443..fa8b33b98a 100644
--- a/tensorflow/core/debug/debug_graph_utils.h
+++ b/tensorflow/core/debug/debug_graph_utils.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_DEBUG_NODE_INSERTER_H_
#include <unordered_map>
-#include <unordered_set>
#include <vector>
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
@@ -29,35 +28,6 @@ limitations under the License.
namespace tensorflow {
-class DebuggerState : public DebuggerStateInterface {
- public:
- DebuggerState(const DebugOptions& debug_options);
- virtual ~DebuggerState();
-
- // Returns a summary string for RepeatedPtrFields of DebugTensorWatches.
- const string SummarizeDebugTensorWatches() override;
-
- // Insert special-purpose debug nodes to graph. See the documentation of
- // DebugNodeInserter::InsertNodes() for details.
- Status DecorateGraphForDebug(Graph* graph, Device* device) override;
-
- const protobuf::RepeatedPtrField<DebugTensorWatch>& watches;
-
- // Publish metadata about the debugged Session::Run() call.
- //
- // See the doc string of DebuggerStateInterface::PublishDebugMetadata() for
- // details.
- Status PublishDebugMetadata(const int64 global_step,
- const int64 session_run_count,
- const int64 executor_step_count,
- const std::vector<string>& input_names,
- const std::vector<string>& output_names,
- const std::vector<string>& target_names) override;
-
- private:
- std::unordered_set<string> debug_urls_;
-};
-
class DebugNodeInserter {
public:
// EXPERIMENTAL: Insert special debug ops (e.g., DebugIdentity) to graph for
diff --git a/tensorflow/core/debug/debug_grpc_testlib.cc b/tensorflow/core/debug/debug_grpc_testlib.cc
index 8dfa904489..d9fab87aed 100644
--- a/tensorflow/core/debug/debug_grpc_testlib.cc
+++ b/tensorflow/core/debug/debug_grpc_testlib.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/debug/debug_io_utils.h"
+#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
@@ -33,26 +34,32 @@ namespace test {
Event event;
while (stream->Read(&event)) {
- const Summary::Value& val = event.summary().value(0);
-
- std::vector<string> name_items =
- tensorflow::str_util::Split(val.node_name(), ':');
-
- const string node_name = name_items[0];
- int32 output_slot = 0;
- tensorflow::strings::safe_strto32(name_items[1], &output_slot);
- const string debug_op = name_items[2];
-
- const TensorProto& tensor_proto = val.tensor();
- Tensor tensor(tensor_proto.dtype());
- if (!tensor.FromProto(tensor_proto)) {
- return ::grpc::Status::CANCELLED;
+ if (event.has_log_message()) {
+ debug_metadata_strings.push_back(event.log_message().message());
+ } else if (!event.graph_def().empty()) {
+ encoded_graph_defs.push_back(event.graph_def());
+ } else if (event.has_summary()) {
+ const Summary::Value& val = event.summary().value(0);
+
+ std::vector<string> name_items =
+ tensorflow::str_util::Split(val.node_name(), ':');
+
+ const string node_name = name_items[0];
+ int32 output_slot = 0;
+ tensorflow::strings::safe_strto32(name_items[1], &output_slot);
+ const string debug_op = name_items[2];
+
+ const TensorProto& tensor_proto = val.tensor();
+ Tensor tensor(tensor_proto.dtype());
+ if (!tensor.FromProto(tensor_proto)) {
+ return ::grpc::Status::CANCELLED;
+ }
+
+ node_names.push_back(node_name);
+ output_slots.push_back(output_slot);
+ debug_ops.push_back(debug_op);
+ debug_tensors.push_back(tensor);
}
-
- node_names.push_back(node_name);
- output_slots.push_back(output_slot);
- debug_ops.push_back(debug_op);
- debug_tensors.push_back(tensor);
}
{
@@ -79,6 +86,8 @@ namespace test {
}
void TestEventListenerImpl::ClearReceivedDebugData() {
+ debug_metadata_strings.clear();
+ encoded_graph_defs.clear();
node_names.clear();
output_slots.clear();
debug_ops.clear();
diff --git a/tensorflow/core/debug/debug_grpc_testlib.h b/tensorflow/core/debug/debug_grpc_testlib.h
index 0e3223dbfe..c2b96e78c5 100644
--- a/tensorflow/core/debug/debug_grpc_testlib.h
+++ b/tensorflow/core/debug/debug_grpc_testlib.h
@@ -47,6 +47,8 @@ class TestEventListenerImpl final : public EventListener::Service {
const int32 output_slot,
const string& debug_op);
+ std::vector<string> debug_metadata_strings;
+ std::vector<string> encoded_graph_defs;
std::vector<string> node_names;
std::vector<int32> output_slots;
std::vector<string> debug_ops;
diff --git a/tensorflow/core/debug/debugger_state_impl.cc b/tensorflow/core/debug/debugger_state_impl.cc
new file mode 100644
index 0000000000..d5752c0002
--- /dev/null
+++ b/tensorflow/core/debug/debugger_state_impl.cc
@@ -0,0 +1,66 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/debug/debugger_state_impl.h"
+
+#include "tensorflow/core/debug/debug_graph_utils.h"
+#include "tensorflow/core/debug/debug_io_utils.h"
+
+namespace tensorflow {
+
+DebuggerState::DebuggerState(const DebugOptions& debug_options) {
+ for (const DebugTensorWatch& watch :
+ debug_options.debug_tensor_watch_opts()) {
+ for (const string& url : watch.debug_urls()) {
+ debug_urls_.insert(url);
+ }
+ }
+}
+
+DebuggerState::~DebuggerState() {
+ for (const string& debug_url : debug_urls_) {
+ DebugIO::CloseDebugURL(debug_url).IgnoreError();
+ }
+}
+
+Status DebuggerState::PublishDebugMetadata(
+ const int64 global_step, const int64 session_run_count,
+ const int64 executor_step_count, const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_names) {
+ return DebugIO::PublishDebugMetadata(global_step, session_run_count,
+ executor_step_count, input_names,
+ output_names, target_names, debug_urls_);
+}
+
+Status DebugGraphDecorator::DecorateGraph(Graph* graph, Device* device) {
+ DebugNodeInserter::DeparallelizeWhileLoops(graph, device);
+ return DebugNodeInserter::InsertNodes(
+ debug_options_.debug_tensor_watch_opts(), graph, device);
+}
+
+Status DebugGraphDecorator::PublishGraph(const Graph& graph) {
+ std::unordered_set<string> debug_urls;
+ for (const DebugTensorWatch& watch :
+ debug_options_.debug_tensor_watch_opts()) {
+ for (const string& url : watch.debug_urls()) {
+ debug_urls.insert(url);
+ }
+ }
+
+ return DebugIO::PublishGraph(graph, debug_urls);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/debug/debugger_state_impl.h b/tensorflow/core/debug/debugger_state_impl.h
new file mode 100644
index 0000000000..d91aa03426
--- /dev/null
+++ b/tensorflow/core/debug/debugger_state_impl.h
@@ -0,0 +1,61 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_DEBUGGER_STATE_IMPL_H_
+#define TENSORFLOW_DEBUGGER_STATE_IMPL_H_
+
+#include "tensorflow/core/common_runtime/debugger_state_interface.h"
+
+#include <unordered_set>
+#include <vector>
+
+namespace tensorflow {
+
+class DebuggerState : public DebuggerStateInterface {
+ public:
+ DebuggerState(const DebugOptions& debug_options);
+ virtual ~DebuggerState();
+
+ // Publish metadata about the debugged Session::Run() call.
+ //
+ // See the doc string of DebuggerStateInterface::PublishDebugMetadata() for
+ // details.
+ Status PublishDebugMetadata(const int64 global_step,
+ const int64 session_run_count,
+ const int64 executor_step_count,
+ const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_names) override;
+
+ private:
+ std::unordered_set<string> debug_urls_;
+};
+
+class DebugGraphDecorator : public DebugGraphDecoratorInterface {
+ public:
+ DebugGraphDecorator(const DebugOptions& debug_options)
+ : debug_options_(debug_options) {}
+ virtual ~DebugGraphDecorator() {}
+
+ Status DecorateGraph(Graph* graph, Device* device) override;
+ Status PublishGraph(const Graph& graph) override;
+
+ private:
+ DebugOptions debug_options_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_DEBUGGER_STATE_IMPL_H_
diff --git a/tensorflow/core/debug/grpc_session_debug_test.cc b/tensorflow/core/debug/grpc_session_debug_test.cc
new file mode 100644
index 0000000000..6c68729410
--- /dev/null
+++ b/tensorflow/core/debug/grpc_session_debug_test.cc
@@ -0,0 +1,288 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/debug/debug_io_utils.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/summary.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/graph/default_device.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/debug.pb.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/port.h"
+
+namespace tensorflow {
+
+static SessionOptions Devices(int num_cpus, int num_gpus) {
+ SessionOptions result;
+ (*result.config.mutable_device_count())["CPU"] = num_cpus;
+ (*result.config.mutable_device_count())["GPU"] = num_gpus;
+ return result;
+}
+
+void CreateGraphDef(GraphDef* graph_def, string node_names[3]) {
+ Graph graph(OpRegistry::Global());
+
+ Tensor a_tensor(DT_FLOAT, TensorShape({1, 2}));
+ test::FillValues<float>(&a_tensor, {1.0, 2.0});
+ Node* a = test::graph::Constant(&graph, a_tensor);
+ node_names[0] = a->name();
+
+ Tensor b_tensor(DT_FLOAT, TensorShape({2, 1}));
+ test::FillValues<float>(&b_tensor, {2.0, 1.0});
+ Node* b = test::graph::Constant(&graph, b_tensor);
+ node_names[1] = b->name();
+
+ // c = a * b
+ Node* c = test::graph::Matmul(&graph, a, b, false, false);
+ node_names[2] = c->name();
+
+ test::graph::ToGraphDef(&graph, graph_def);
+}
+
+// Asserts that "val" is a single float tensor. The only float is
+// "expected_val".
+static void IsSingleFloatValue(const Tensor& val, float expected_val) {
+ ASSERT_EQ(val.dtype(), DT_FLOAT);
+ ASSERT_EQ(val.NumElements(), 1);
+ ASSERT_EQ(val.flat<float>()(0), expected_val);
+}
+
+static SessionOptions Options(const string& target, int placement_period) {
+ SessionOptions options;
+ // NOTE(mrry): GrpcSession requires a grpc:// scheme prefix in the target
+ // string.
+ options.target = strings::StrCat("grpc://", target);
+ options.config.set_placement_period(placement_period);
+ options.config.mutable_graph_options()
+ ->mutable_optimizer_options()
+ ->set_opt_level(OptimizerOptions::L0);
+ return options;
+}
+
+static Session* NewRemote(const SessionOptions& options) {
+ return CHECK_NOTNULL(NewSession(options));
+}
+
+class GrpcSessionDebugTest : public ::testing::Test {
+ protected:
+ void SetUp() override { CreateDumpDir(); }
+
+ void TearDown() override { DeleteDumpDir(); }
+
+ void DeleteDumpDir() {
+ if (Env::Default()->IsDirectory(dump_dir_).ok()) {
+ int64 undeleted_files = 0;
+ int64 undeleted_dirs = 0;
+ ASSERT_TRUE(
+ Env::Default()
+ ->DeleteRecursively(dump_dir_, &undeleted_files, &undeleted_dirs)
+ .ok());
+ ASSERT_EQ(0, undeleted_files);
+ ASSERT_EQ(0, undeleted_dirs);
+ }
+ }
+
+ const string GetDebugURL() { return debug_url_; }
+
+ void LoadTensorDumps(const string& subdir, std::vector<Tensor>* tensors) {
+ const string dirpath = io::JoinPath(dump_dir_, subdir);
+ if (!(Env::Default()->IsDirectory(dirpath).ok())) {
+ return;
+ }
+
+ std::vector<string> filenames;
+ TF_ASSERT_OK(Env::Default()->GetChildren(dirpath, &filenames));
+
+ for (const string& filename : filenames) {
+ Event event;
+ TF_ASSERT_OK(ReadEventFromFile(io::JoinPath(dirpath, filename), &event));
+ if (event.summary().value().size() == 1) {
+ Tensor tensor;
+ ASSERT_TRUE(tensor.FromProto(event.summary().value(0).tensor()));
+ tensors->push_back(tensor);
+ }
+ }
+ }
+
+ private:
+ void CreateDumpDir() {
+ char dir_template[] = "/tmp/tfdbg_grpc_sessions_XXXXXX";
+ dump_dir_ = mkdtemp(dir_template);
+ debug_url_ = strings::StrCat("file://", dump_dir_);
+ }
+
+ string dump_dir_;
+ string debug_url_;
+};
+
+TEST_F(GrpcSessionDebugTest, FileDebugURL) {
+ GraphDef graph;
+ string node_names[3];
+ CreateGraphDef(&graph, node_names);
+
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+
+ std::unique_ptr<Session> session(
+ NewRemote(Options(cluster->targets()[0], 1)));
+ ASSERT_TRUE(session != nullptr);
+ TF_CHECK_OK(session->Create(graph));
+
+ // Iteration 0: No watch.
+ // Iterations 1 and 2: Watch one Tensor.
+ // Iterations 3 and 4: Watch two Tensors.
+ // Iteration 5: No watch.
+ for (size_t i = 0; i < 6; ++i) {
+ RunOptions options;
+ if (i >= 1 && i < 5) {
+ DebugOptions* debug_options = options.mutable_debug_options();
+ DebugTensorWatch* watch = debug_options->add_debug_tensor_watch_opts();
+ watch->set_node_name(node_names[0]);
+ watch->set_output_slot(0);
+ watch->add_debug_ops("DebugIdentity");
+ watch->add_debug_urls(GetDebugURL());
+
+ if (i >= 3) {
+ watch = debug_options->add_debug_tensor_watch_opts();
+ watch->set_node_name(node_names[1]);
+ watch->set_output_slot(0);
+ watch->add_debug_ops("DebugIdentity");
+ watch->add_debug_urls(GetDebugURL());
+ }
+ }
+
+ RunMetadata metadata;
+ std::vector<Tensor> outputs;
+ TF_CHECK_OK(
+ session->Run(options, {}, {node_names[2]}, {}, &outputs, &metadata));
+ ASSERT_EQ(1, outputs.size());
+ IsSingleFloatValue(outputs[0], 4.0);
+
+ std::vector<Tensor> dumped_tensors;
+ LoadTensorDumps("n", &dumped_tensors);
+
+ if (i == 0 || i == 5) {
+ ASSERT_EQ(0, dumped_tensors.size());
+ } else {
+ if (i == 1 || i == 2) {
+ ASSERT_EQ(1, dumped_tensors.size());
+ ASSERT_EQ(TensorShape({1, 2}), dumped_tensors[0].shape());
+ ASSERT_EQ(1.0, dumped_tensors[0].flat<float>()(0));
+ ASSERT_EQ(2.0, dumped_tensors[0].flat<float>()(1));
+ } else {
+ ASSERT_EQ(2, dumped_tensors.size());
+ }
+ DeleteDumpDir();
+ }
+ }
+ TF_CHECK_OK(session->Close());
+}
+
+void SetDevice(GraphDef* graph, const string& name, const string& dev) {
+ for (size_t i = 0; i < graph->node_size(); ++i) {
+ if (graph->node(i).name() == name) {
+ graph->mutable_node(i)->set_device(dev);
+ return;
+ }
+ }
+ LOG(FATAL) << "Name '" << name << "' not found.";
+}
+
+TEST_F(GrpcSessionDebugTest, MultiDevices_String) {
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 1), 2, &cluster));
+ std::unique_ptr<Session> session(
+ NewRemote(Options(cluster->targets()[0], 1000)));
+ ASSERT_TRUE(session != nullptr);
+
+ // b = a
+ Graph graph(OpRegistry::Global());
+ Tensor a_tensor(DT_STRING, TensorShape({2, 2}));
+ for (size_t i = 0; i < 4; ++i) {
+ a_tensor.flat<string>()(i) = "hello, world";
+ }
+ Node* a = test::graph::Constant(&graph, a_tensor);
+ Node* b = test::graph::Identity(&graph, a);
+
+ GraphDef def;
+ test::graph::ToGraphDef(&graph, &def);
+
+ // In this test, we force each node (a, b) on every possible device.
+ // We test all possible cases.
+ for (const auto& a_dev : cluster->devices()) {
+ for (const auto& b_dev : cluster->devices()) {
+ LOG(INFO) << "a: " << a_dev.name() << " b: " << b_dev.name();
+ SetDevice(&def, a->name(), a_dev.name());
+ SetDevice(&def, b->name(), b_dev.name());
+
+ Status s = session->Create(def);
+ if (s.ok()) {
+ std::vector<Tensor> outputs;
+
+ RunOptions options;
+ DebugOptions* debug_options = options.mutable_debug_options();
+ DebugTensorWatch* watch = debug_options->add_debug_tensor_watch_opts();
+ watch->set_node_name(a->name());
+ watch->set_output_slot(0);
+ watch->add_debug_ops("DebugIdentity");
+ watch->add_debug_urls(GetDebugURL());
+
+ RunMetadata metadata;
+ TF_CHECK_OK(
+ session->Run(options, {}, {b->name()}, {}, &outputs, &metadata));
+ ASSERT_EQ(1, outputs.size());
+ ASSERT_EQ(outputs[0].dtype(), DT_STRING);
+ ASSERT_EQ(outputs[0].NumElements(), 4);
+ for (size_t i = 0; i < outputs[0].NumElements(); ++i) {
+ EXPECT_EQ(outputs[0].flat<string>()(i), "hello, world");
+ }
+ TF_CHECK_OK(session->Close());
+
+ std::vector<Tensor> dumped_tensors;
+ LoadTensorDumps("n", &dumped_tensors);
+ ASSERT_EQ(1, dumped_tensors.size());
+ ASSERT_EQ(TensorShape({2, 2}), dumped_tensors[0].shape());
+ for (size_t i = 0; i < 4; ++i) {
+ ASSERT_EQ("hello, world", dumped_tensors[0].flat<string>()(i));
+ }
+
+ DeleteDumpDir();
+ } else {
+ LOG(ERROR) << "Error: " << s;
+ ASSERT_TRUE((a_dev.device_type() == DEVICE_GPU) ||
+ (b_dev.device_type() == DEVICE_GPU));
+ ASSERT_FALSE(s.ok());
+ }
+ }
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 0c2d2b5d5d..0f5eb0cb32 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -259,6 +259,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/debug:debug_graph_utils",
],
)
@@ -329,6 +330,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc",
+ "//tensorflow/core/debug",
],
)
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 36b7b5b628..6051f11fad 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/constant_folding.h"
+#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -94,6 +95,21 @@ static Status ValidateGraphDefForDevices(const GraphDef& gdef) {
return Status::OK();
}
+Status GraphMgr::DecorateAndPublishGraphForDebug(
+ const DebugOptions& debug_options, Graph* graph, Device* device) {
+ std::unique_ptr<DebugGraphDecoratorInterface> decorator =
+ DebugGraphDecoratorRegistry::CreateDecorator(debug_options);
+ if (!decorator) {
+ return errors::Internal(
+ "Debugger options are set, but creation of debug graph publisher ",
+ "failed.");
+ }
+ TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
+ TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph));
+
+ return Status::OK();
+}
+
// Creates executors given a graph definition "gdef" of a "session".
// If a node in "gdef" is shared by other graphs in "session", the
// same op kernel is reused. E.g., typically a params node is shared
@@ -106,7 +122,8 @@ static Status ValidateGraphDefForDevices(const GraphDef& gdef) {
// "executors" are filled with one executor per device if success and
// the caller takes the ownership of returned executors.
Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
- const GraphOptions& graph_options, Item* item) {
+ const GraphOptions& graph_options,
+ const DebugOptions& debug_options, Item* item) {
item->session = session;
item->lib_def =
new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library());
@@ -232,6 +249,13 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
};
optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph);
+
+ // EXPERIMENTAL: tfdbg inserts debug nodes (i.e., probes) to the graph.
+ if (!debug_options.debug_tensor_watch_opts().empty()) {
+ TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
+ debug_options, subgraph.get(), params.device));
+ }
+
TF_RETURN_IF_ERROR(
EnsureMemoryTypes(DeviceType(unit->device->device_type()),
unit->device->name(), subgraph.get()));
@@ -247,9 +271,10 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
}
Status GraphMgr::Register(const string& session, const GraphDef& gdef,
- const GraphOptions& graph_options, string* handle) {
+ const GraphOptions& graph_options,
+ const DebugOptions& debug_options, string* handle) {
Item* item = new Item;
- Status s = InitItem(session, gdef, graph_options, item);
+ Status s = InitItem(session, gdef, graph_options, debug_options, item);
if (!s.ok()) {
item->Unref();
return s;
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h
index 5f51d63857..349af6c54e 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.h
+++ b/tensorflow/core/distributed_runtime/graph_mgr.h
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/protobuf/debug.pb.h"
namespace tensorflow {
@@ -67,7 +68,8 @@ class GraphMgr {
// Registers a graph. Fills in "handle"
Status Register(const string& session, const GraphDef& gdef,
- const GraphOptions& graph_options, string* handle);
+ const GraphOptions& graph_options,
+ const DebugOptions& debug_options, string* handle);
// Executes one step of a registered graph "handle".
//
@@ -167,7 +169,11 @@ class GraphMgr {
const StatusCallback& done);
Status InitItem(const string& session, const GraphDef& gdef,
- const GraphOptions& graph_options, Item* item);
+ const GraphOptions& graph_options,
+ const DebugOptions& debug_options, Item* item);
+
+ Status DecorateAndPublishGraphForDebug(const DebugOptions& debug_options,
+ Graph* graph, Device* device);
TF_DISALLOW_COPY_AND_ASSIGN(GraphMgr);
};
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 73d4e6ab00..681933adad 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/profile_handler.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
+#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/distributed_runtime/scheduler.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
@@ -67,6 +68,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
client_graph_(std::move(cg)),
session_opts_(session_opts),
is_partial_(is_partial),
+ debug_opts_(bopts.debug_options),
worker_cache_(worker_cache) {
VLOG(1) << "Created ReffedClientGraph for node with "
<< client_graph_->graph.num_node_ids();
@@ -206,6 +208,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
const std::unique_ptr<SimpleClientGraph> client_graph_;
const SessionOptions session_opts_;
const bool is_partial_;
+ const DebugOptions& debug_opts_;
WorkerCacheInterface* const worker_cache_; // Not owned.
std::unordered_map<StringPiece, Node*, StringPiece::Hasher> name_to_node_;
@@ -406,6 +409,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
// For simplicity, we ship the library completely to every worker.
*c->req.mutable_graph_def()->mutable_library() = func_def_lib;
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
+ *c->req.mutable_debug_options() = debug_opts_;
VLOG(2) << "Register " << c->req.graph_def().DebugString();
auto cb = [c, &done](const Status& s) {
c->status = s;
@@ -899,6 +903,10 @@ void BuildBuildGraphOptions(const RunStepRequestWrapper& req,
opts->target_nodes.push_back(req.target_name(i));
}
+ if (!req.options().debug_options().debug_tensor_watch_opts().empty()) {
+ opts->debug_options = req.options().debug_options();
+ }
+
std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end());
std::sort(opts->target_nodes.begin(), opts->target_nodes.end());
std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
@@ -916,6 +924,8 @@ void BuildBuildGraphOptions(const PartialRunSetupRequest& req,
opts->target_nodes.push_back(target);
}
+ // TODO(cais): Add TFDBG support to partial runs.
+
std::sort(opts->feed_endpoints.begin(), opts->feed_endpoints.end());
std::sort(opts->target_nodes.begin(), opts->target_nodes.end());
std::sort(opts->fetch_endpoints.begin(), opts->fetch_endpoints.end());
@@ -932,6 +942,13 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
for (const string& name : opts.fetch_endpoints) {
h = Hash64(name.c_str(), name.size(), h);
}
+
+ if (!opts.debug_options.debug_tensor_watch_opts().empty()) {
+ const string watch_summary = SummarizeDebugTensorWatches(
+ opts.debug_options.debug_tensor_watch_opts());
+ h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
+ }
+
return h;
}
@@ -1069,6 +1086,7 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
handle_, opts, std::move(client_graph), session_opts_,
stats_publisher_factory_, execution_state_.get(), is_partial,
env_->worker_cache);
+
iter = m->insert({hash, entry}).first;
VLOG(1) << "Preparing to execute new graph";
}
@@ -1315,6 +1333,43 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
return s;
}
+Status MasterSession::CreateDebuggerState(
+ const DebugOptions& debug_options, const RunStepRequestWrapper& req,
+ int64 rcg_execution_count,
+ std::unique_ptr<DebuggerStateInterface>* debugger_state) {
+ std::unique_ptr<DebuggerStateInterface> state =
+ DebuggerStateRegistry::CreateState(debug_options);
+ if (!state) {
+ return errors::Internal(
+ "Debugger options are set, but creation of debugger state failed. "
+ "It appears that debugger is not linked in this TensorFlow build.");
+ }
+
+ std::vector<string> input_names;
+ for (size_t i = 0; i < req.num_feeds(); ++i) {
+ input_names.push_back(req.feed_name(i));
+ }
+ std::vector<string> output_names;
+ for (size_t i = 0; i < req.num_fetches(); ++i) {
+ output_names.push_back(req.fetch_name(i));
+ }
+ std::vector<string> target_names;
+ for (size_t i = 0; i < req.num_targets(); ++i) {
+ target_names.push_back(req.target_name(i));
+ }
+
+ // TODO(cais): We currently use -1 as a dummy value for session run count.
+ // While this counter value is straightforward to define and obtain for
+ // DirectSessions, it is less so for non-direct Sessions. Devise a better
+ // way to get its value when the need arises.
+ TF_RETURN_IF_ERROR(state->PublishDebugMetadata(
+ debug_options.global_step(), -1, rcg_execution_count, input_names,
+ output_names, target_names));
+
+ *debugger_state = std::move(state);
+ return Status::OK();
+}
+
Status MasterSession::DoRunWithLocalExecution(
CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp) {
@@ -1333,6 +1388,13 @@ Status MasterSession::DoRunWithLocalExecution(
// Unref "rcg" when out of scope.
core::ScopedUnref unref(rcg);
+ std::unique_ptr<DebuggerStateInterface> debugger_state;
+ const DebugOptions& debug_options = req.options().debug_options();
+
+ if (!debug_options.debug_tensor_watch_opts().empty()) {
+ TF_RETURN_IF_ERROR(
+ CreateDebuggerState(debug_options, req, count, &debugger_state));
+ }
TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
// Keeps the highest 8 bits 0x01: we reserve some bits of the
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index ee1a340c8e..d47125be99 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <atomic>
#include <vector>
+#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/simple_graph_execution_state.h"
#include "tensorflow/core/common_runtime/stats_publisher_interface.h"
@@ -193,6 +194,11 @@ class MasterSession : public core::RefCounted {
Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);
+ Status CreateDebuggerState(
+ const DebugOptions& debug_options, const RunStepRequestWrapper& req,
+ int64 rcg_execution_count,
+ std::unique_ptr<DebuggerStateInterface>* debugger_state);
+
TF_DISALLOW_COPY_AND_ASSIGN(MasterSession);
};
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index 1aced4443f..89639e21b5 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -55,7 +55,7 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
env_->session_mgr->WorkerSessionForSession(request->session_handle());
Status s = session->graph_mgr->Register(
request->session_handle(), request->graph_def(), request->graph_options(),
- response->mutable_graph_handle());
+ request->debug_options(), response->mutable_graph_handle());
if (s.ok()) {
env_->session_mgr->AssociateGraphWithSession(request->session_handle(),
response->graph_handle());
diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto
index e3af1119e9..661327847c 100644
--- a/tensorflow/core/protobuf/worker.proto
+++ b/tensorflow/core/protobuf/worker.proto
@@ -28,6 +28,7 @@ import "tensorflow/core/framework/device_attributes.proto";
import "tensorflow/core/framework/graph.proto";
import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/protobuf/config.proto";
+import "tensorflow/core/protobuf/debug.proto";
import "tensorflow/core/protobuf/named_tensor.proto";
import "tensorflow/core/protobuf/tensorflow_server.proto";
@@ -92,6 +93,9 @@ message RegisterGraphRequest {
// Configuration options for the session in which this graph was created.
GraphOptions graph_options = 4;
+
+ // Field(s) used by TensorFlow Debugger (tfdbg).
+ DebugOptions debug_options = 5;
}
message RegisterGraphResponse {