aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/debug
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-19 18:55:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-19 19:02:27 -0700
commit5ce3523bcc844217b47e7f862c1bed894cbaa34e (patch)
treedf131f41215ac39e448e5931da5d6a72605a8b16 /tensorflow/core/debug
parent7ad8e25495a2793ea14189359af736d2c662a694 (diff)
Extending the core DebugIdentity tensorflow operation with support for writing
to a singleton in memory datastructure that records a mapping from debug_urls to debug events. This simplifies reading a large number of states without writing to disk or making internal RPC calls for arbitrary nodes. PiperOrigin-RevId: 169337269
Diffstat (limited to 'tensorflow/core/debug')
-rw-r--r--tensorflow/core/debug/BUILD33
-rw-r--r--tensorflow/core/debug/debug_callback_registry.cc49
-rw-r--r--tensorflow/core/debug/debug_callback_registry.h71
-rw-r--r--tensorflow/core/debug/debug_io_utils.cc43
-rw-r--r--tensorflow/core/debug/debug_io_utils.h24
-rw-r--r--tensorflow/core/debug/debug_io_utils_test.cc34
-rw-r--r--tensorflow/core/debug/debug_node_key.cc53
-rw-r--r--tensorflow/core/debug/debug_node_key.h51
8 files changed, 303 insertions, 55 deletions
diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD
index 0eafe35e38..525f96a3de 100644
--- a/tensorflow/core/debug/BUILD
+++ b/tensorflow/core/debug/BUILD
@@ -7,6 +7,10 @@
# DebuggerState to be constructed at initialization time, enabling
# TensorFlow Debugger (tfdbg) support. For details, please see
# core/common_runtime/debugger_state_interface.h.
+# ":debug_callback_registry" - Depending on this target exposes a global
+# callback registry that will be used to record any observed tensors matching
+# a watch state.
+# ":debug_node_key" - Defines a struct used for tracking tensors.
package(
default_visibility = ["//tensorflow:internal"],
@@ -134,6 +138,8 @@ tf_cuda_library(
copts = tf_copts(),
linkstatic = 1,
deps = [
+ ":debug_callback_registry",
+ ":debug_node_key",
":debug_service_proto_cc",
":debugger_event_metadata_proto_cc",
"//tensorflow/core:core_cpu_internal",
@@ -167,6 +173,18 @@ tf_cuda_library(
alwayslink = 1,
)
+tf_cuda_library(
+ name = "debug_node_key",
+ srcs = ["debug_node_key.cc"],
+ hdrs = ["debug_node_key.h"],
+ copts = tf_copts(),
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
# TODO(cais): Fix flakiness on GPU and change this back to a tf_cc_test_gpu.
# See b/34081273.
tf_cc_test(
@@ -206,8 +224,10 @@ tf_cc_test(
srcs = ["debug_io_utils_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
+ ":debug_callback_registry",
":debug_grpc_testlib",
":debug_io_utils",
+ ":debug_node_key",
":debug_service_proto_cc",
":debugger_event_metadata_proto_cc",
"//tensorflow/core:core_cpu",
@@ -286,6 +306,19 @@ tf_cc_test(
],
)
+cc_library(
+ name = "debug_callback_registry",
+ srcs = ["debug_callback_registry.cc"],
+ hdrs = ["debug_callback_registry.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":debug_node_key",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
# TODO(cais): Add the following back in when tfdbg is supported on Android.
# filegroup(
# name = "android_srcs",
diff --git a/tensorflow/core/debug/debug_callback_registry.cc b/tensorflow/core/debug/debug_callback_registry.cc
new file mode 100644
index 0000000000..97967a3f04
--- /dev/null
+++ b/tensorflow/core/debug/debug_callback_registry.cc
@@ -0,0 +1,49 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/debug/debug_callback_registry.h"
+
+namespace tensorflow {
+
+DebugCallbackRegistry::DebugCallbackRegistry() {}
+
+/*static */ DebugCallbackRegistry* DebugCallbackRegistry::instance_ = nullptr;
+
+DebugCallbackRegistry* DebugCallbackRegistry::singleton() {
+ if (instance_ == nullptr) {
+ instance_ = new DebugCallbackRegistry();
+ }
+ return instance_;
+}
+
+void DebugCallbackRegistry::RegisterCallback(const string& key,
+ EventCallback callback) {
+ mutex_lock lock(mu_);
+ keyed_callback_[key] = std::move(callback);
+}
+
+DebugCallbackRegistry::EventCallback* DebugCallbackRegistry::GetCallback(
+ const string& key) {
+ mutex_lock lock(mu_);
+ auto iter = keyed_callback_.find(key);
+ return iter == keyed_callback_.end() ? nullptr : &iter->second;
+}
+
+void DebugCallbackRegistry::UnregisterCallback(const string& key) {
+ mutex_lock lock(mu_);
+ keyed_callback_.erase(key);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/debug/debug_callback_registry.h b/tensorflow/core/debug/debug_callback_registry.h
new file mode 100644
index 0000000000..8f08c656c2
--- /dev/null
+++ b/tensorflow/core/debug/debug_callback_registry.h
@@ -0,0 +1,71 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_DEBUG_CALLBACK_REGISTRY_H_
+#define TENSORFLOW_DEBUG_CALLBACK_REGISTRY_H_
+
+#include <functional>
+#include <map>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/debug/debug_node_key.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// Supports exporting observed debug events to clients using registered
+// callbacks. Users can register a callback for each debug_url stored using
+// DebugTensorWatch. The callback key be equivalent to what follows
+// "memcbk:///".
+//
+// All events generated for a watched node will be sent to the call back in the
+// order that they are observed.
+//
+// This callback router should not be used in production or training steps. It
+// is optimized for deep inspection of graph state rather than performance.
+class DebugCallbackRegistry {
+ public:
+ using EventCallback = std::function<void(const DebugNodeKey&, const Tensor&)>;
+
+ // Provides singleton access to the in memory event store.
+ static DebugCallbackRegistry* singleton();
+
+ // Returns the registered callback, or nullptr, for key.
+ EventCallback* GetCallback(const string& key);
+
+ // Associates callback with key. This must be called by clients observing
+ // nodes to be exported by this callback router before running a session.
+ void RegisterCallback(const string& key, EventCallback callback);
+
+ // Removes the callback associated with key.
+ void UnregisterCallback(const string& key);
+
+ private:
+ DebugCallbackRegistry();
+
+ // Mutex to ensure that keyed events are never updated in parallel.
+ mutex mu_;
+
+ // Maps debug_url keys to callbacks for routing observed tensors.
+ std::map<string, EventCallback> keyed_callback_ GUARDED_BY(mu_);
+
+ static DebugCallbackRegistry* instance_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_DEBUG_CALLBACK_REGISTRY_H_
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 546cde4c16..86f66f909e 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -29,6 +29,7 @@ limitations under the License.
#pragma comment(lib, "Ws2_32.lib")
#endif // #ifndef PLATFORM_WINDOWS
+#include "tensorflow/core/debug/debug_callback_registry.h"
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
@@ -280,35 +281,12 @@ Status PublishEncodedGraphDefInChunks(const string& encoded_graph_def,
const char* const DebugIO::kDebuggerPluginName = "debugger";
-const char* const DebugIO::kMetadataFilePrefix = "_tfdbg_";
-
const char* const DebugIO::kCoreMetadataTag = "core_metadata_";
-const char* const DebugIO::kDeviceTag = "device_";
-
const char* const DebugIO::kGraphTag = "graph_";
const char* const DebugIO::kHashTag = "hash";
-DebugNodeKey::DebugNodeKey(const string& device_name, const string& node_name,
- const int32 output_slot, const string& debug_op)
- : device_name(device_name),
- node_name(node_name),
- output_slot(output_slot),
- debug_op(debug_op),
- debug_node_name(
- strings::StrCat(node_name, ":", output_slot, ":", debug_op)),
- device_path(DeviceNameToDevicePath(device_name)) {}
-
-bool DebugNodeKey::operator==(const DebugNodeKey& other) const {
- return (device_name == other.device_name && node_name == other.node_name &&
- output_slot == other.output_slot && debug_op == other.debug_op);
-}
-
-bool DebugNodeKey::operator!=(const DebugNodeKey& other) const {
- return !((*this) == other);
-}
-
Status ReadEventFromFile(const string& dump_file_path, Event* event) {
Env* env(Env::Default());
@@ -338,16 +316,9 @@ Status ReadEventFromFile(const string& dump_file_path, Event* event) {
return Status::OK();
}
-const string DebugNodeKey::DeviceNameToDevicePath(const string& device_name) {
- return strings::StrCat(
- DebugIO::kMetadataFilePrefix, DebugIO::kDeviceTag,
- str_util::StringReplace(
- str_util::StringReplace(device_name, ":", "_", true), "/", ",",
- true));
-}
-
const char* const DebugIO::kFileURLScheme = "file://";
const char* const DebugIO::kGrpcURLScheme = "grpc://";
+const char* const DebugIO::kMemoryURLScheme = "memcbk://";
// Publishes debug metadata to a set of debug URLs.
Status DebugIO::PublishDebugMetadata(
@@ -423,7 +394,7 @@ Status DebugIO::PublishDebugMetadata(
const string core_metadata_path = AppendTimestampToFilePath(
io::JoinPath(
dump_root_dir,
- strings::StrCat(DebugIO::kMetadataFilePrefix,
+ strings::StrCat(DebugNodeKey::kMetadataFilePrefix,
DebugIO::kCoreMetadataTag, "sessionrun",
strings::Printf("%.14lld", session_run_index))),
Env::Default()->NowMicros());
@@ -465,6 +436,12 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
#else
GRPC_OSS_WINDOWS_UNIMPLEMENTED_ERROR;
#endif
+ } else if (str_util::Lowercase(url).find(kMemoryURLScheme) == 0) {
+ const string dump_root_dir = url.substr(strlen(kMemoryURLScheme));
+ auto* callback_registry = DebugCallbackRegistry::singleton();
+ auto* callback = callback_registry->GetCallback(dump_root_dir);
+ CHECK(callback) << "No callback registered for: " << dump_root_dir;
+ (*callback)(debug_node_key, tensor);
} else {
return Status(error::UNAVAILABLE,
strings::StrCat("Invalid debug target URL: ", url));
@@ -515,7 +492,7 @@ Status DebugIO::PublishGraph(const Graph& graph, const string& device_name,
DebugNodeKey::DeviceNameToDevicePath(device_name));
const uint64 graph_hash = ::tensorflow::Hash64(buf);
const string file_name =
- strings::StrCat(DebugIO::kMetadataFilePrefix, DebugIO::kGraphTag,
+ strings::StrCat(DebugNodeKey::kMetadataFilePrefix, DebugIO::kGraphTag,
DebugIO::kHashTag, graph_hash, "_", now_micros);
status.Update(
diff --git a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h
index 75fc2b07f3..023d7a7ee0 100644
--- a/tensorflow/core/debug/debug_io_utils.h
+++ b/tensorflow/core/debug/debug_io_utils.h
@@ -24,6 +24,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "tensorflow/core/debug/debug_node_key.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
@@ -45,39 +46,18 @@ struct DebugWatchAndURLSpec {
const bool gated_grpc;
};
-struct DebugNodeKey {
- DebugNodeKey(const string& device_name, const string& node_name,
- const int32 output_slot, const string& debug_op);
-
- // Converts a device name string to a device path string.
- // E.g., /job:localhost/replica:0/task:0/cpu:0 will be converted to
- // ,job_localhost,replica_0,task_0,cpu_0.
- static const string DeviceNameToDevicePath(const string& device_name);
-
- bool operator==(const DebugNodeKey& other) const;
- bool operator!=(const DebugNodeKey& other) const;
-
- const string device_name;
- const string node_name;
- const int32 output_slot;
- const string debug_op;
- const string debug_node_name;
- const string device_path;
-};
-
// TODO(cais): Put static functions and members in a namespace, not a class.
class DebugIO {
public:
static const char* const kDebuggerPluginName;
- static const char* const kMetadataFilePrefix;
static const char* const kCoreMetadataTag;
- static const char* const kDeviceTag;
static const char* const kGraphTag;
static const char* const kHashTag;
static const char* const kFileURLScheme;
static const char* const kGrpcURLScheme;
+ static const char* const kMemoryURLScheme;
static Status PublishDebugMetadata(
const int64 global_step, const int64 session_run_index,
diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc
index c0bb65e7f4..2f83c2415b 100644
--- a/tensorflow/core/debug/debug_io_utils_test.cc
+++ b/tensorflow/core/debug/debug_io_utils_test.cc
@@ -17,6 +17,8 @@ limitations under the License.
#include "tensorflow/core/debug/debug_io_utils.h"
+#include "tensorflow/core/debug/debug_callback_registry.h"
+#include "tensorflow/core/debug/debug_node_key.h"
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@@ -307,6 +309,38 @@ TEST_F(DebugIOUtilsTest, PublishTensorToMultipleFileURLs) {
}
}
+TEST_F(DebugIOUtilsTest, PublishTensorToMemoryCallback) {
+ Initialize();
+
+ const DebugNodeKey kDebugNodeKey("/job:localhost/replica:0/task:0/cpu:0",
+ "foo/bar/qux/tensor_a", 0, "DebugIdentity");
+ const uint64 wall_time = env_->NowMicros();
+
+ bool called = false;
+ std::vector<string> urls = {"memcbk://test_callback"};
+ ;
+
+ auto* callback_registry = DebugCallbackRegistry::singleton();
+ callback_registry->RegisterCallback(
+ "test_callback", [this, &kDebugNodeKey, &called](const DebugNodeKey& key,
+ const Tensor& tensor) {
+ called = true;
+ ASSERT_EQ(kDebugNodeKey.device_name, key.device_name);
+ ASSERT_EQ(kDebugNodeKey.node_name, key.node_name);
+ ASSERT_EQ(tensor_a_->shape(), tensor.shape());
+ for (int i = 0; i < tensor.flat<float>().size(); ++i) {
+ ASSERT_EQ(tensor_a_->flat<float>()(i), tensor.flat<float>()(i));
+ }
+ });
+
+ Status s =
+ DebugIO::PublishDebugTensor(kDebugNodeKey, *tensor_a_, wall_time, urls);
+ ASSERT_TRUE(s.ok());
+ ASSERT_TRUE(called);
+
+ callback_registry->UnregisterCallback("test_callback");
+}
+
TEST_F(DebugIOUtilsTest, PublishTensorConcurrentlyToPartiallyOverlappingPaths) {
Initialize();
diff --git a/tensorflow/core/debug/debug_node_key.cc b/tensorflow/core/debug/debug_node_key.cc
new file mode 100644
index 0000000000..4b56fe8358
--- /dev/null
+++ b/tensorflow/core/debug/debug_node_key.cc
@@ -0,0 +1,53 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/debug/debug_node_key.h"
+
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+const char* const DebugNodeKey::kMetadataFilePrefix = "_tfdbg_";
+
+const char* const DebugNodeKey::kDeviceTag = "device_";
+
+DebugNodeKey::DebugNodeKey(const string& device_name, const string& node_name,
+ const int32 output_slot, const string& debug_op)
+ : device_name(device_name),
+ node_name(node_name),
+ output_slot(output_slot),
+ debug_op(debug_op),
+ debug_node_name(
+ strings::StrCat(node_name, ":", output_slot, ":", debug_op)),
+ device_path(DeviceNameToDevicePath(device_name)) {}
+
+bool DebugNodeKey::operator==(const DebugNodeKey& other) const {
+ return (device_name == other.device_name && node_name == other.node_name &&
+ output_slot == other.output_slot && debug_op == other.debug_op);
+}
+
+bool DebugNodeKey::operator!=(const DebugNodeKey& other) const {
+ return !((*this) == other);
+}
+
+const string DebugNodeKey::DeviceNameToDevicePath(const string& device_name) {
+ return strings::StrCat(
+ kMetadataFilePrefix, kDeviceTag,
+ str_util::StringReplace(
+ str_util::StringReplace(device_name, ":", "_", true), "/", ",",
+ true));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/debug/debug_node_key.h b/tensorflow/core/debug/debug_node_key.h
new file mode 100644
index 0000000000..b46054c013
--- /dev/null
+++ b/tensorflow/core/debug/debug_node_key.h
@@ -0,0 +1,51 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_DEBUG_NODE_KEY_H_
+#define TENSORFLOW_DEBUG_NODE_KEY_H_
+
+#include <string>
+
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Encapsulates debug information for a node that was observed.
+struct DebugNodeKey {
+ static const char* const kMetadataFilePrefix;
+ static const char* const kDeviceTag;
+
+ DebugNodeKey(const string& device_name, const string& node_name,
+ const int32 output_slot, const string& debug_op);
+
+ // Converts a device name string to a device path string.
+ // E.g., /job:localhost/replica:0/task:0/cpu:0 will be converted to
+ // ,job_localhost,replica_0,task_0,cpu_0.
+ static const string DeviceNameToDevicePath(const string& device_name);
+
+ bool operator==(const DebugNodeKey& other) const;
+ bool operator!=(const DebugNodeKey& other) const;
+
+ const string device_name;
+ const string node_name;
+ const int32 output_slot;
+ const string debug_op;
+ const string debug_node_name;
+ const string device_path;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_DEBUG_NODE_KEY_H_