author    Shanqing Cai <cais@google.com>    2017-07-13 07:40:00 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-07-13 07:44:31 -0700
commit 31f1375c01782f53cf12ce2aa4efa584483c14ea (patch)
tree a5422d5a1823fb892fb0aceb4272464b134d2ed9
parent 1eccfb6e190cdd594c2cd1d3b27323f87614d8bc (diff)
Add GradientsDebugger to tfdbg
to allow retrieval of gradient tensors created by TensorFlow's automatic differentiation algorithm (i.e., tf.gradients and optimizer code that uses it).

PiperOrigin-RevId: 161805516
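A minimal usage sketch, pieced together from the docstring examples added in this commit (the variable names and the 0.1 learning rate are illustrative):

```python
import tensorflow as tf
from tensorflow.python import debug as tf_debug

x = tf.Variable(1.0, name="x")
y = tf.add(x, x, name="y")

# Wrap y in a debug identity so that gradients flowing through it are
# registered with this GradientsDebugger instance.
grad_debugger = tf_debug.GradientsDebugger()
debug_y = grad_debugger.identify_gradient(y)
z = tf.square(debug_y, name="z")

# Build the backward graph under the debugger's context.
with grad_debugger:
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(z)

# Retrieve the registered gradient tensor dz/dy by x-tensor or by name.
y_grad = grad_debugger.gradient_tensor(y)
y_grad = grad_debugger.gradient_tensor("y:0")
```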
-rw-r--r--  tensorflow/contrib/cmake/tf_core_ops.cmake           |  12
-rwxr-xr-x  tensorflow/contrib/cmake/tf_python.cmake             |   2
-rw-r--r--  tensorflow/core/ops/debug_ops.cc                     |   3
-rw-r--r--  tensorflow/python/debug/BUILD                        |  40
-rw-r--r--  tensorflow/python/debug/__init__.py                  |   5
-rw-r--r--  tensorflow/python/debug/lib/debug_gradients.py       | 417
-rw-r--r--  tensorflow/python/debug/lib/debug_gradients_test.py  | 378
7 files changed, 857 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index 21d94da6cc..b350b822dd 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -129,3 +129,15 @@ list(REMOVE_ITEM tf_core_ops_srcs ${tf_core_ops_exclude_srcs})
add_library(tf_core_ops OBJECT ${tf_core_ops_srcs})
add_dependencies(tf_core_ops tf_core_cpu)
+
+########################################################
+# tf_debug_ops library
+########################################################
+
+file(GLOB tf_debug_ops_srcs
+ "${tensorflow_source_dir}/tensorflow/core/ops/debug_ops.cc"
+)
+
+add_library(tf_debug_ops OBJECT ${tf_debug_ops_srcs})
+
+add_dependencies(tf_debug_ops tf_core_framework)
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index beb630bdfc..0024efab38 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -692,6 +692,8 @@ GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py)
GENERATE_PYTHON_OP_LIB("stateless_random_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py)
+GENERATE_PYTHON_OP_LIB("debug_ops"
+ DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/debug/ops/gen_debug_ops.py)
add_custom_target(tf_python_ops SOURCES ${tf_python_ops_generated_files} ${PYTHON_PROTO_GENFILES})
add_dependencies(tf_python_ops tf_python_op_gen_main)
diff --git a/tensorflow/core/ops/debug_ops.cc b/tensorflow/core/ops/debug_ops.cc
index f7a96b58da..bd7f7c2c01 100644
--- a/tensorflow/core/ops/debug_ops.cc
+++ b/tensorflow/core/ops/debug_ops.cc
@@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/
// This file registers all TensorFlow Debugger (tfdbg) ops.
+#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
@@ -88,6 +90,7 @@ REGISTER_OP("DebugIdentity")
.Attr("debug_urls: list(string) = []")
.Attr("gated_grpc: bool = false")
.SetAllowsUninitializedInput()
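+ // DebugIdentity passes its input through, so the output shape equals
+ // the input shape.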
+ .SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Debug Identity Op.
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 86d7b5f770..d7ca36b12b 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -23,6 +23,7 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow:tensorflow.bzl", "if_not_windows")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
py_library(
name = "debug_py",
@@ -31,6 +32,7 @@ py_library(
visibility = ["//visibility:public"],
deps = [
":debug_data",
+ ":debug_gradients",
":debug_utils",
":hooks",
":local_cli_wrapper",
@@ -60,6 +62,24 @@ py_library(
],
)
+tf_gen_op_wrapper_py(
+ name = "debug_ops",
+ deps = ["//tensorflow/core:debug_ops_op_lib"],
+)
+
+py_library(
+ name = "debug_gradients",
+ srcs = ["lib/debug_gradients.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":debug_data",
+ ":debug_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:platform",
+ "@six_archive//:six",
+ ],
+)
+
py_library(
name = "debug_utils",
srcs = ["lib/debug_utils.py"],
@@ -381,6 +401,26 @@ py_test(
],
)
+cuda_py_test(
+ name = "debug_gradients_test",
+ size = "small",
+ srcs = [
+ "lib/debug_gradients_test.py",
+ ],
+ additional_deps = [
+ ":debug_data",
+ ":debug_gradients",
+ ":debug_utils",
+ "//tensorflow/python:client",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ ],
+)
+
py_test(
name = "debug_utils_test",
size = "small",
diff --git a/tensorflow/python/debug/__init__.py b/tensorflow/python/debug/__init__.py
index 147e42d878..e20849bc1c 100644
--- a/tensorflow/python/debug/__init__.py
+++ b/tensorflow/python/debug/__init__.py
@@ -31,6 +31,9 @@ See the @{$python/tfdbg} guide.
@@LocalCLIDebugHook
@@LocalCLIDebugWrapperSession
@@WatchOptions
+
+@@GradientsDebugger
+@@clear_gradient_debuggers
"""
from __future__ import absolute_import
@@ -44,6 +47,8 @@ from tensorflow.python.debug.lib.debug_data import has_inf_or_nan
from tensorflow.python.debug.lib.debug_data import load_tensor_from_event
from tensorflow.python.debug.lib.debug_data import load_tensor_from_event_file
+from tensorflow.python.debug.lib.debug_gradients import GradientsDebugger
+
from tensorflow.python.debug.lib.debug_utils import add_debug_tensor_watch
from tensorflow.python.debug.lib.debug_utils import watch_graph
from tensorflow.python.debug.lib.debug_utils import watch_graph_with_blacklists
diff --git a/tensorflow/python/debug/lib/debug_gradients.py b/tensorflow/python/debug/lib/debug_gradients.py
new file mode 100644
index 0000000000..e5159cb941
--- /dev/null
+++ b/tensorflow/python/debug/lib/debug_gradients.py
@@ -0,0 +1,417 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorFlow Debugger: Tools for debugging gradients."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+import uuid
+
+import six
+
+from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.ops import gen_debug_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
+
+_GRADIENT_DEBUG_TAG = "gradient_debug_"
+
+_gradient_debuggers = {}
+
+
+def _tensor_to_grad_debug_op_name(tensor, grad_debugger_uuid):
+ op_name, slot = debug_data.parse_node_or_tensor_name(tensor.name)
+ return "%s_%d/%s%s" % (op_name, slot, _GRADIENT_DEBUG_TAG, grad_debugger_uuid)
+
+
+def _parse_grad_debug_op_name(op_name):
+ """Parse the name of a debug gradient op.
+
+ Args:
+ op_name: the name of the debug gradient op.
+
+ Returns:
+ 1) The UUID of the GradientsDebugger that created the debug gradient op.
+ 2) Name of the original tensor whose gradient is debugged by the debug
+ gradient op.
+ """
+ name_items = op_name.split("/")
+ assert len(name_items) > 1
+ assert name_items[-1].startswith(_GRADIENT_DEBUG_TAG)
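+
+  # Illustrative example: an op named "foo/y_0/gradient_debug_abcd" was
+  # created for tensor "foo/y:0" by the GradientsDebugger whose UUID is
+  # "abcd"; this function then returns ("abcd", "foo/y:0").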
+
+ grad_debugger_uuid = name_items[-1][len(_GRADIENT_DEBUG_TAG):]
+ if "_" in grad_debugger_uuid:
+ grad_debugger_uuid = grad_debugger_uuid[:grad_debugger_uuid.index("_")]
+ orig_tensor_slot = int(name_items[-2][name_items[-2].rfind("_") + 1:])
+ orig_base_op_name = name_items[-2][:name_items[-2].rfind("_")]
+ orig_tensor_name = ("/".join(name_items[:-2] + [orig_base_op_name]) +
+ ":%d" % orig_tensor_slot)
+
+ return grad_debugger_uuid, orig_tensor_name
+
+
+class GradientsDebugger(object):
+ """Gradients Debugger.
+
+ Allows retrieval of gradient tensors created by TensorFlow's automatic
+ differentiation algorithm, i.e., @{tf.gradients} and optimizer classes that
+ use it.
+ """
+  # TODO(cais): Add example code in the docstring?
+
+ def __init__(self, y_tensor=None):
+ """Constructor of GradientsDebugger.
+
+ Args:
+ y_tensor: optional: the `tf.Tensor` to be differentiated, i.e., the tensor
+ on the numerator of the differentiation.
+ """
+
+ self._uuid = uuid.uuid4().hex
+ _gradient_debuggers[self._uuid] = self
+
+    # A dict mapping x-tensor names to gradient tensors. x-tensor refers to the
+ # independent tf.Tensor, i.e., the tensor on the denominator of the
+ # differentiation.
+ self._gradient_tensors = {}
+ self._y_tensor = y_tensor
+
+ self._graph = None
+ if y_tensor:
+ self._graph = y_tensor.graph
+
+ self._is_active_context = False
+
+ @property
+ def y_tensor(self):
+ return self._y_tensor
+
+ @property
+ def graph(self):
+ return self._graph
+
+ def __enter__(self):
+ self._is_active_context = True
+
+ def __exit__(self, unused_type, unused_value, unused_traceback):
+ self._is_active_context = False
+
+ def identify_gradient(self, input_tensor):
+ """Create a debug identity tensor that registers and forwards gradients.
+
+    The side effect of this method is that when gradient tensor(s) are created
+    along any paths that include the `input_tensor`, the gradient tensor(s)
+    with respect to `input_tensor` will be registered with this
+    `GradientsDebugger` instance and can later be retrieved with the methods
+    `gradient_tensor` and `gradient_tensors`.
+
+ Example:
+
+ ```python
+ x = tf.Variable(1.0)
+ y = tf.add(x, x)
+
+ grad_debugger = tf_debug.GradientsDebugger()
+ debug_y = grad_debugger.identify_gradient(y)
+ z = tf.square(debug_y)
+
+ # Create a train op under the grad_debugger context.
+ with grad_debugger:
+      train_op = tf.train.GradientDescentOptimizer(0.1).minimize(z)
+
+ # Now we can reflect through grad_debugger to get the gradient tensor
+ # with respect to y.
+ y_grad = grad_debugger.gradient_tensor(y)
+ ```
+
+ Args:
+      input_tensor: the input `tf.Tensor` object whose related gradient tensors
+        are to be registered with this `GradientsDebugger` instance when they
+        are created, e.g., during @{tf.gradients} calls or the construction
+        of an optimization (training) op that uses @{tf.gradients}.
+
+ Returns:
+ A forwarded identity of `input_tensor`, as a `tf.Tensor`.
+
+ Raises:
+ ValueError: If an op with name that duplicates the gradient-debugging op
+ already exists in the graph (highly unlikely).
+ """
+ # TODO(cais): Allow overriding gradient.
+ # TODO(cais): Implement value_stack.
+ grad_debug_op_name = _tensor_to_grad_debug_op_name(input_tensor, self._uuid)
+ debug_identity = gen_debug_ops.debug_identity(
+ input_tensor,
+ tensor_name=input_tensor.name,
+ debug_urls=[],
+ name=grad_debug_op_name)
+ if debug_identity.op.name != grad_debug_op_name:
+ raise ValueError(
+ "The graph already contains an op named %s" % grad_debug_op_name)
+ return debug_identity
+
+ def watch_gradients_by_tensors(self, graph, tensors):
+ """Watch gradient tensors by x-tensor(s).
+
+    The side effect of this method is that when gradient tensor(s) are created
+    along any paths that include the given x-tensor(s), the gradient tensor(s)
+    with respect to those tensors will be registered with this
+    `GradientsDebugger` instance and can later be retrieved with the methods
+    `gradient_tensor` and `gradient_tensors`.
+
+ Unlike the method `identify_gradient`, this method is used to retrieve
+ gradient tensors after the construction of the forward subgraph has
+ completed (but before the construction of the backward subgraph).
+
+    This method is the same as `watch_gradients_by_tensor_names` except that
+    the tensors are specified by the Python `tf.Tensor` or `tf.Variable`
+    objects, instead of by name patterns.
+
+ Example:
+
+ ```python
+ x = tf.Variable(1.0)
+ y = tf.add(x, x, name="y")
+    z = tf.square(y)
+
+ # Create a train op under the grad_debugger context.
+ grad_debugger = tf_debug.GradientsDebugger()
+    with grad_debugger.watch_gradients_by_tensors(tf.get_default_graph(), y):
+      train_op = tf.train.GradientDescentOptimizer(0.1).minimize(z)
+
+ # Now we can reflect through grad_debugger to get the gradient tensor
+ # with respect to y.
+ y_grad = grad_debugger.gradient_tensor(y)
+ # or
+ y_grad = grad_debugger.gradient_tensor("y:0")
+ ```
+
+ Args:
+ graph: the `tf.Graph` to watch the gradients on.
+ tensors: a `tf.Tensor` or `tf.Variable` object, or a list of such objects.
+
+ Returns:
+ The GradientsDebugger instance itself.
+ """
+
+ if not isinstance(tensors, list):
+ tensors = [tensors]
+
+ tensor_name_regex = []
+ for tensor in tensors:
+ tensor_name_regex.append(re.escape(tensor.name) + "$")
+ tensor_name_regex = "(" + "|".join(tensor_name_regex) + ")"
+ return self.watch_gradients_by_tensor_names(graph, tensor_name_regex)
+
+ def watch_gradients_by_tensor_names(self, graph, tensor_name_regex):
+ """Watch gradient tensors by name(s) of the x-tensor(s).
+
+ The side effect of this method is that when gradient tensor(s) are created
+ with respect to the x-tensors, the gradient tensor(s) will be registered
+ with this `GradientsDebugger` instance and can later be retrieved.
+
+ Unlike the `identify_gradient` method, this method is used after the
+    construction of the forward graph has completed. Unlike the
+    `watch_gradients_by_tensors` method, this method does not use handles to
+    the tensors of interest; it uses their names.
+
+ This method is the same as `watch_gradients_by_tensors` except that the
+ x-tensors are specified by name patterns, instead of `tf.Tensor` or
+ `tf.Variable` objects.
+
+ Example:
+
+ ```python
+ x = tf.Variable(1.0, name="x")
+ y = tf.add(x, x, name="y")
+    z = tf.square(y)
+
+ # Create a train op under the grad_debugger context.
+ grad_debugger = tf_debug.GradientsDebugger()
+    with grad_debugger.watch_gradients_by_tensor_names(
+        tf.get_default_graph(), r"(x|y):0$"):
+      train_op = tf.train.GradientDescentOptimizer(0.1).minimize(z)
+
+ # Now we can reflect through grad_debugger to get the gradient tensor
+ # with respect to x and y.
+ x_grad = grad_debugger.gradient_tensor("x:0")
+ y_grad = grad_debugger.gradient_tensor("y:0")
+ ```
+
+ Args:
+ graph: the `tf.Graph` to watch the gradients on.
+ tensor_name_regex: the regular-expression pattern of the name(s) of the
+ x-tensor(s) to watch. x-tensor refers to the tensors on the denominator
+ of the differentiation.
+
+ Returns:
+ The GradientsDebugger instance itself.
+ """
+ tensor_name_pattern = re.compile(tensor_name_regex)
+
+ # pylint: disable=protected-access
+ with graph.as_default():
+ for op in graph.get_operations():
+ for output in op.outputs:
+ if tensor_name_pattern.match(output.name):
+ debug_op = self.identify_gradient(output)
+
+ for consumer in output.consumers():
+ if consumer == debug_op.op:
+ continue
+
+ # Locate the slot index of the original input.
+ input_slots = []
+ for i, consumer_input in enumerate(consumer._inputs):
+ if consumer_input == output:
+ input_slots.append(i)
+
+ for slot in input_slots:
+ consumer._inputs[slot] = debug_op
+ debug_op._consumers.append(consumer)
+
+ del output._consumers[:]
+ output._consumers.append(debug_op.op)
+ # pylint: enable=protected-access
+
+ return self
+
+ def _check_same_graph(self, tensor):
+ if self._graph is None:
+ self._graph = tensor.graph
+ elif self._graph != tensor.graph:
+ raise ValueError(
+ "The graph of the value (%s) is not the same as the graph %s" %
+ (tensor.graph, self._graph))
+
+ def register_gradient_tensor(self,
+ x_tensor_name,
+ gradient_tensor):
+ """Register the gradient tensor for an x-tensor.
+
+ Args:
+      x_tensor_name: (`str`) the name of the independent `tf.Tensor`, i.e.,
+ the tensor on the denominator of the differentiation.
+ gradient_tensor: the gradient `tf.Tensor`.
+ """
+ if len(_gradient_debuggers) == 1 or self._is_active_context:
+ self._check_same_graph(gradient_tensor)
+ self._gradient_tensors[x_tensor_name] = gradient_tensor
+
+ def gradient_tensor(self, x_tensor):
+ """Get the gradient tensor of an x-tensor.
+
+ Args:
+ x_tensor: (`tf.Tensor`, `tf.Variable` or `str`) The x-tensor object or its
+ name. x-tensor refers to the independent `tf.Tensor`, i.e., the tensor
+ on the denominator of the differentiation.
+
+ Returns:
+ If found, the gradient tensor.
+
+ Raises:
+ TypeError: If `x_tensor` is not a `tf.Tensor`, `tf.Variable` or `str`.
+ LookupError: If the `x_tensor` has not been registered with a gradient
+ tensor.
+ """
+ x_tensor_name = self._get_tensor_name(x_tensor)
+ if x_tensor_name not in self._gradient_tensors:
+ raise LookupError(
+ "This GradientsDebugger has not received any gradient tensor for "
+ "x-tensor %s" % x_tensor_name)
+ return self._gradient_tensors[x_tensor_name]
+
+ def gradient_tensors(self):
+ """Get the gradient tensors that this object is aware of.
+
+ Returns:
+ A dict mapping x-tensor names to gradient tensor objects. x-tensor refers
+      to the tensors on the denominator of the differentiation.
+ """
+ return self._gradient_tensors
+
+ def _get_tensor_name(self, tensor):
+ if isinstance(tensor, (ops.Tensor, variables.Variable)):
+ return tensor.name
+ elif isinstance(tensor, six.string_types):
+ return tensor
+ else:
+ raise TypeError(
+ "x_tensor must be a str or tf.Tensor or tf.Variable, "
+ "but instead has type %s" % type(tensor))
+
+
+def clear_gradient_debuggers():
+ """Clear all globally registered gradient debuggers."""
+ _gradient_debuggers.clear()
+
+
+@ops.RegisterGradient("DebugIdentity")
+def _identify_gradient_grad(op, dy):
+ """Gradient function for the DebugIdentity op."""
+ # TODO(cais): Allow overriding gradient.
+ grad_debugger_uuid, orig_tensor_name = _parse_grad_debug_op_name(op.name)
+ grad_debugger = _gradient_debuggers[grad_debugger_uuid]
+ grad_debugger.register_gradient_tensor(orig_tensor_name, dy)
+ return dy
+
+
+def gradient_values_from_dump(grad_debugger, x_tensor, dump):
+ """Find gradient values from a `DebugDumpDir` object.
+
+ Args:
+ grad_debugger: the `tf_debug.GradientsDebugger` instance to be used.
+ x_tensor: (`tf.Tensor`, `tf.Variable` or `str`) The x-tensor object or its
+ name. x-tensor refers to the independent `tf.Tensor`, i.e., the tensor
+ on the denominator of the differentiation.
+ dump: A `tfdbg.DebugDumpDir` object.
+
+ Returns:
+ If this `GradientsDebugger` instance has the gradient tensor of `x_tensor`
+ registered: a list of `numpy.ndarray` representing the value of the
+ gradient tensor from `dump`. The list could be empty, if the gradient
+ tensor is not executed in the `tf.Session.run()` call that generated
+ the `dump`. The list could also contain multiple values of the gradient
+ tensor, e.g., if gradient tensor is computed repeatedly in a
+ `tf.while_loop` during the run that generated the `dump`.
+
+ Raises:
+ LookupError: If this `GradientsDebugger` instance does not have the
+ gradient tensor of `x_tensor` registered.
+ ValueError: If this `GradientsDebugger` has a `tf.Graph` object that
+ does not match the `tf.Graph` object of the `dump`.
+ TypeError: If `x_tensor` is not a `tf.Tensor`, `tf.Variable` or `str`.
+ """
+ # TODO(cais): Use this method in LocalCLIDebugWrapperSession to present the
+ # gradient tensors to the TFDBG CLI.
+
+ # If possible, verify that the Python graph of the dump and that of this
+ # GradientsDebugger match.
+ if (dump.python_graph and grad_debugger.graph and
+ dump.python_graph != grad_debugger.graph):
+ raise ValueError(
+ "This GradientsDebugger instance has a graph (%s) that differs from "
+ "the graph of the DebugDumpDir object (%s)." %
+ (grad_debugger.graph, dump.python_graph))
+
+ gradient_tensor = grad_debugger.gradient_tensor(x_tensor)
+ node_name, output_slot = debug_data.parse_node_or_tensor_name(
+ gradient_tensor.name)
+
+ try:
+ return dump.get_tensors(node_name, output_slot, "DebugIdentity")
+ except debug_data.WatchKeyDoesNotExistInDebugDumpDirError:
+ return []
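A sketch of the dump-based retrieval path that `gradient_values_from_dump` enables, condensed from the unit test that follows (the temporary dump directory and the 0.1 learning rate are illustrative):

```python
import tempfile

import tensorflow as tf
from tensorflow.python import debug as tf_debug
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_gradients
from tensorflow.python.debug.lib import debug_utils

sess = tf.Session()
u = tf.Variable(2.0, name="u")
v = tf.Variable(3.0, name="v")
w = tf.multiply(u, v, name="w")
y = tf.add(w, -1.0, name="y")
z = tf.square(y, name="z")

grad_debugger = tf_debug.GradientsDebugger()
with grad_debugger.watch_gradients_by_tensors(sess.graph, [w, y]):
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(z)
sess.run(tf.global_variables_initializer())

# Instrument the run so that intermediate tensors, including the watched
# gradients, are dumped to disk as the train op executes.
run_options = tf.RunOptions(output_partition_graphs=True)
dump_dir = tempfile.mkdtemp()
debug_utils.watch_graph(run_options, sess.graph,
                        debug_urls="file://" + dump_dir)
run_metadata = tf.RunMetadata()
sess.run(train_op, options=run_options, run_metadata=run_metadata)

# Load the dump and look up the value(s) of dz/dy from this run.
dump = debug_data.DebugDumpDir(
    dump_dir, partition_graphs=run_metadata.partition_graphs)
dump.set_python_graph(sess.graph)
y_grad_values = debug_gradients.gradient_values_from_dump(
    grad_debugger, y, dump)  # e.g., a list like [10.0]
```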
diff --git a/tensorflow/python/debug/lib/debug_gradients_test.py b/tensorflow/python/debug/lib/debug_gradients_test.py
new file mode 100644
index 0000000000..966578320e
--- /dev/null
+++ b/tensorflow/python/debug/lib/debug_gradients_test.py
@@ -0,0 +1,378 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit tests for debug_gradients module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import shutil
+import tempfile
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.lib import debug_gradients
+from tensorflow.python.debug.lib import debug_utils
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import gradient_descent
+
+
+class IdentifyGradientTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self.sess = session.Session()
+ with self.sess:
+ self.u = variables.Variable(2.0, name="u")
+ self.v = variables.Variable(3.0, name="v")
+ self.w = math_ops.multiply(self.u.value(), self.v.value(), name="w")
+
+ def tearDown(self):
+ ops.reset_default_graph()
+ debug_gradients.clear_gradient_debuggers()
+
+ def testIdentifyGradientGivesCorrectTensorObjectWithoutContextManager(self):
+ grad_debugger = debug_gradients.GradientsDebugger()
+ id_grad_w = grad_debugger.identify_gradient(self.w)
+ y = math_ops.add(id_grad_w, -1.0, name="y")
+
+ grads = gradients_impl.gradients(y, [self.u, self.v])
+ self.assertEqual(2, len(grads))
+ u_grad = grads[0]
+ v_grad = grads[1]
+
+ self.sess.run(variables.global_variables_initializer())
+ self.assertAllClose(5.0, self.sess.run(y))
+ self.assertAllClose(3.0, self.sess.run(u_grad))
+ self.assertAllClose(2.0, self.sess.run(v_grad))
+
+ # Fetch the gradient tensor with the x-tensor object.
+ w_grad = grad_debugger.gradient_tensor(self.w)
+ self.assertIsInstance(w_grad, ops.Tensor)
+ self.assertAllClose(1.0, self.sess.run(w_grad))
+
+ # Fetch the gradient tensor with the x-tensor's name.
+ w_grad = grad_debugger.gradient_tensor(self.w.name)
+ self.assertIsInstance(w_grad, ops.Tensor)
+ self.assertAllClose(1.0, self.sess.run(w_grad))
+
+ # Fetch the gradient tensor with the x-tensor name.
+ w_grad = grad_debugger.gradient_tensor(self.w.name)
+ self.assertIsInstance(w_grad, ops.Tensor)
+ self.assertAllClose(1.0, self.sess.run(w_grad))
+
+ def testIdentifyGradientGivesCorrectTensorObjectWithTfGradients(self):
+ grad_debugger = debug_gradients.GradientsDebugger()
+ id_grad_w = grad_debugger.identify_gradient(self.w)
+ y = math_ops.add(id_grad_w, -1.0, name="y")
+
+ with grad_debugger:
+ grads = gradients_impl.gradients(y, [self.u, self.v])
+ self.assertEqual(2, len(grads))
+ u_grad = grads[0]
+ v_grad = grads[1]
+
+ self.sess.run(variables.global_variables_initializer())
+ self.assertAllClose(5.0, self.sess.run(y))
+ self.assertAllClose(3.0, self.sess.run(u_grad))
+ self.assertAllClose(2.0, self.sess.run(v_grad))
+
+ # Fetch the gradient tensor with the x-tensor object.
+ w_grad = grad_debugger.gradient_tensor(self.w)
+ self.assertIsInstance(w_grad, ops.Tensor)
+ self.assertAllClose(1.0, self.sess.run(w_grad))
+
+ # Fetch the gradient tensor with the x-tensor's name.
+ w_grad = grad_debugger.gradient_tensor(self.w.name)
+ self.assertIsInstance(w_grad, ops.Tensor)
+ self.assertAllClose(1.0, self.sess.run(w_grad))
+
+ # Fetch the gradient tensor with the x-tensor name.
+ w_grad = grad_debugger.gradient_tensor(self.w.name)
+ self.assertIsInstance(w_grad, ops.Tensor)
+ self.assertAllClose(1.0, self.sess.run(w_grad))
+
+ def testCallingIdentifyGradientTwiceWithTheSameGradientsDebuggerErrors(self):
+ grad_debugger = debug_gradients.GradientsDebugger()
+ grad_debugger.identify_gradient(self.w)
+ with self.assertRaisesRegexp(
+ ValueError, "The graph already contains an op named .*"):
+ grad_debugger.identify_gradient(self.w)
+
+ def testIdentifyGradientWorksOnMultipleLosses(self):
+ grad_debugger_1 = debug_gradients.GradientsDebugger()
+ grad_debugger_2 = debug_gradients.GradientsDebugger()
+
+ y = math_ops.add(self.w, -1.0, name="y")
+ debug_y = grad_debugger_1.identify_gradient(y)
+ z1 = math_ops.square(debug_y, name="z1")
+
+ debug_y = grad_debugger_2.identify_gradient(y)
+ z2 = math_ops.sqrt(debug_y, name="z2")
+
+ with grad_debugger_1:
+ gradient_descent.GradientDescentOptimizer(0.1).minimize(z1)
+ with grad_debugger_2:
+ gradient_descent.GradientDescentOptimizer(0.1).minimize(z2)
+
+ dz1_dy = grad_debugger_1.gradient_tensor(y)
+ dz2_dy = grad_debugger_2.gradient_tensor(y)
+ self.assertIsInstance(dz1_dy, ops.Tensor)
+ self.assertIsInstance(dz2_dy, ops.Tensor)
+ self.assertIsNot(dz1_dy, dz2_dy)
+
+ self.sess.run(variables.global_variables_initializer())
+ self.assertAllClose(5.0 ** 2, self.sess.run(z1))
+ self.assertAllClose(5.0 ** 0.5, self.sess.run(z2))
+ self.assertAllClose(2.0 * 5.0, self.sess.run(dz1_dy))
+ self.assertAllClose(0.5 * (5.0 ** -0.5), self.sess.run(dz2_dy))
+
+ def testIdentifyGradientRaisesLookupErrorForUnknownXTensor(self):
+ grad_debugger_1 = debug_gradients.GradientsDebugger()
+ grad_debugger_2 = debug_gradients.GradientsDebugger()
+ id_grad_w = grad_debugger_1.identify_gradient(self.w)
+ y = math_ops.add(id_grad_w, -1.0, name="y")
+
+ # There are >1 gradient debuggers registered, and grad_debugger is not used
+ # as a context manager here, so the gradient w.r.t. self.w will not be
+ # registered.
+ gradients_impl.gradients(y, [self.u, self.v])
+
+ with self.assertRaisesRegexp(
+ LookupError,
+ r"This GradientsDebugger has not received any gradient tensor for "):
+ grad_debugger_1.gradient_tensor(self.w)
+ with self.assertRaisesRegexp(
+ LookupError,
+ r"This GradientsDebugger has not received any gradient tensor for "):
+ grad_debugger_2.gradient_tensor(self.w)
+
+ def testIdentifyGradientRaisesTypeErrorForNonTensorOrTensorNameInput(self):
+ grad_debugger = debug_gradients.GradientsDebugger()
+ with self.assertRaisesRegexp(
+ TypeError,
+ r"x_tensor must be a str or tf\.Tensor or tf\.Variable, but instead "
+ r"has type .*Operation.*"):
+ grad_debugger.gradient_tensor(variables.global_variables_initializer())
+
+ def testIdentifyGradientTensorWorksWithGradientDescentOptimizer(self):
+ grad_debugger = debug_gradients.GradientsDebugger()
+ id_grad_w = grad_debugger.identify_gradient(self.w)
+ y = math_ops.add(id_grad_w, -1.0, name="y")
+
+ with grad_debugger:
+ gradient_descent.GradientDescentOptimizer(0.1).minimize(y)
+
+ self.sess.run(variables.global_variables_initializer())
+
+ # Fetch the gradient tensor with the x-tensor object.
+ w_grad = grad_debugger.gradient_tensor(self.w)
+ self.assertIsInstance(w_grad, ops.Tensor)
+ self.assertAllClose(1.0, self.sess.run(w_grad))
+
+ def testWatchGradientsByXTensorNamesWorks(self):
+ y = math_ops.add(self.w, -1.0, name="y")
+
+    # The construction of the forward graph has completed.
+ # But we can still get the gradient tensors by using
+ # watch_gradients_by_tensor_names().
+ grad_debugger = debug_gradients.GradientsDebugger()
+ with grad_debugger.watch_gradients_by_tensor_names(self.sess.graph, "w:0$"):
+ grads = gradients_impl.gradients(y, [self.u, self.v])
+ self.assertEqual(2, len(grads))
+ u_grad = grads[0]
+ v_grad = grads[1]
+
+ self.sess.run(variables.global_variables_initializer())
+ self.assertAllClose(5.0, self.sess.run(y))
+ self.assertAllClose(3.0, self.sess.run(u_grad))
+ self.assertAllClose(2.0, self.sess.run(v_grad))
+
+ w_grad = grad_debugger.gradient_tensor(self.w)
+ self.assertIsInstance(w_grad, ops.Tensor)
+ self.assertAllClose(1.0, self.sess.run(w_grad))
+
+ w_grad = grad_debugger.gradient_tensor("w:0")
+ self.assertIsInstance(w_grad, ops.Tensor)
+ self.assertAllClose(1.0, self.sess.run(w_grad))
+
+ def testWatchGradientsByXTensorNamesWorksWithoutContextManager(self):
+ y = math_ops.add(self.w, -1.0, name="y")
+
+    # The construction of the forward graph has completed.
+ # But we can still get the gradient tensors by using
+ # watch_gradients_by_tensor_names().
+ grad_debugger = debug_gradients.GradientsDebugger()
+ grad_debugger.watch_gradients_by_tensor_names(self.sess.graph, "w:0$")
+ grads = gradients_impl.gradients(y, [self.u, self.v])
+ self.assertEqual(2, len(grads))
+ u_grad = grads[0]
+ v_grad = grads[1]
+
+ self.sess.run(variables.global_variables_initializer())
+ self.assertAllClose(5.0, self.sess.run(y))
+ self.assertAllClose(3.0, self.sess.run(u_grad))
+ self.assertAllClose(2.0, self.sess.run(v_grad))
+
+ w_grad = grad_debugger.gradient_tensor(self.w)
+ self.assertIsInstance(w_grad, ops.Tensor)
+ self.assertAllClose(1.0, self.sess.run(w_grad))
+
+ w_grad = grad_debugger.gradient_tensor("w:0")
+ self.assertIsInstance(w_grad, ops.Tensor)
+ self.assertAllClose(1.0, self.sess.run(w_grad))
+
+ def testWatchGradientsWorksOnRefTensor(self):
+ y = math_ops.add(self.w, -1.0, name="y")
+
+ grad_debugger = debug_gradients.GradientsDebugger()
+ with grad_debugger.watch_gradients_by_tensor_names(self.sess.graph, "u:0$"):
+ grads = gradients_impl.gradients(y, [self.u, self.v])
+ self.assertEqual(2, len(grads))
+ u_grad = grads[0]
+ v_grad = grads[1]
+
+ self.assertIs(u_grad, grad_debugger.gradient_tensor("u:0"))
+
+ self.sess.run(variables.global_variables_initializer())
+ self.assertAllClose(3.0, self.sess.run(u_grad))
+ self.assertAllClose(2.0, self.sess.run(v_grad))
+ self.assertAllClose(
+ 3.0, self.sess.run(grad_debugger.gradient_tensor("u:0")))
+
+ def testWatchGradientsWorksOnMultipleTensors(self):
+ y = math_ops.add(self.w, -1.0, name="y")
+
+ grad_debugger = debug_gradients.GradientsDebugger()
+ with grad_debugger.watch_gradients_by_tensor_names(self.sess.graph,
+ "(u|w):0$"):
+ grads = gradients_impl.gradients(y, [self.u, self.v])
+ self.assertEqual(2, len(grads))
+ u_grad = grads[0]
+
+ self.assertEqual(2, len(grad_debugger.gradient_tensors()))
+ self.assertIs(u_grad, grad_debugger.gradient_tensor("u:0"))
+ self.assertIsInstance(grad_debugger.gradient_tensor("w:0"), ops.Tensor)
+
+ self.sess.run(variables.global_variables_initializer())
+ self.assertAllClose(
+ 1.0, self.sess.run(grad_debugger.gradient_tensor("w:0")))
+ self.assertAllClose(
+ 3.0, self.sess.run(grad_debugger.gradient_tensor("u:0")))
+
+ def testWatchGradientsByXTensorsWorks(self):
+ y = math_ops.add(self.w, -1.0, name="foo/y")
+ z = math_ops.square(y, name="foo/z")
+
+    # The construction of the forward graph has completed.
+    # But we can still get the gradient tensors by using
+    # watch_gradients_by_tensors().
+ grad_debugger = debug_gradients.GradientsDebugger()
+ with grad_debugger.watch_gradients_by_tensors(
+ self.sess.graph, [self.w, self.u, y]):
+ gradient_descent.GradientDescentOptimizer(0.1).minimize(z)
+
+ self.assertEqual(3, len(grad_debugger.gradient_tensors()))
+ u_grad = grad_debugger.gradient_tensor(self.u)
+ w_grad = grad_debugger.gradient_tensor(self.w)
+ y_grad = grad_debugger.gradient_tensor(y)
+
+ self.sess.run(variables.global_variables_initializer())
+ self.assertAllClose(10.0, self.sess.run(y_grad))
+ self.assertAllClose(10.0, self.sess.run(w_grad))
+ self.assertAllClose(30.0, self.sess.run(u_grad))
+
+ def testWatchGradientsByTensorCanWorkOnMultipleLosses(self):
+ y = math_ops.add(self.w, -1.0, name="y")
+ z1 = math_ops.square(y, name="z1")
+ z2 = math_ops.sqrt(y, name="z2")
+
+ grad_debugger_1 = debug_gradients.GradientsDebugger()
+ with grad_debugger_1.watch_gradients_by_tensors(self.sess.graph, y):
+ gradient_descent.GradientDescentOptimizer(0.1).minimize(z1)
+
+ grad_debugger_2 = debug_gradients.GradientsDebugger()
+ with grad_debugger_2.watch_gradients_by_tensors(self.sess.graph, y):
+ gradient_descent.GradientDescentOptimizer(0.1).minimize(z2)
+
+ dz1_dy = grad_debugger_1.gradient_tensor(y)
+ dz2_dy = grad_debugger_2.gradient_tensor(y)
+ self.assertIsInstance(dz1_dy, ops.Tensor)
+ self.assertIsInstance(dz2_dy, ops.Tensor)
+ self.assertIsNot(dz1_dy, dz2_dy)
+
+ self.sess.run(variables.global_variables_initializer())
+ self.assertAllClose(5.0 ** 2, self.sess.run(z1))
+ self.assertAllClose(5.0 ** 0.5, self.sess.run(z2))
+ self.assertAllClose(2.0 * 5.0, self.sess.run(dz1_dy))
+ self.assertAllClose(0.5 * (5.0 ** -0.5), self.sess.run(dz2_dy))
+
+ def testGradientsValuesFromDumpWorks(self):
+ y = math_ops.add(self.w, -1.0, name="y")
+ z = math_ops.square(y, name="z")
+
+ grad_debugger = debug_gradients.GradientsDebugger()
+ with grad_debugger.watch_gradients_by_tensors(
+ self.sess.graph, [self.w, self.u, y]):
+ train_op = gradient_descent.GradientDescentOptimizer(0.1).minimize(z)
+
+ self.sess.run(variables.global_variables_initializer())
+
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ dump_dir = tempfile.mkdtemp()
+ debug_url = "file://" + dump_dir
+ debug_utils.watch_graph(
+ run_options,
+ self.sess.graph,
+ debug_urls=debug_url)
+ run_metadata = config_pb2.RunMetadata()
+ self.sess.run(train_op, options=run_options, run_metadata=run_metadata)
+
+ dump = debug_data.DebugDumpDir(
+ dump_dir, partition_graphs=run_metadata.partition_graphs)
+ dump.set_python_graph(self.sess.graph)
+
+ y_grad_values = debug_gradients.gradient_values_from_dump(
+ grad_debugger, y, dump)
+ self.assertEqual(1, len(y_grad_values))
+ self.assertAllClose(10.0, y_grad_values[0])
+
+ w_grad_values = debug_gradients.gradient_values_from_dump(
+ grad_debugger, self.w, dump)
+ self.assertEqual(1, len(w_grad_values))
+ self.assertAllClose(10.0, w_grad_values[0])
+
+ u_grad_values = debug_gradients.gradient_values_from_dump(
+ grad_debugger, self.u, dump)
+ self.assertEqual(1, len(u_grad_values))
+ self.assertAllClose(30.0, u_grad_values[0])
+
+ with self.assertRaisesRegexp(
+ LookupError,
+ r"This GradientsDebugger has not received any gradient tensor for "
+ r"x-tensor v:0"):
+ debug_gradients.gradient_values_from_dump(grad_debugger, self.v, dump)
+
+ # Cleanup.
+ shutil.rmtree(dump_dir)
+
+
+if __name__ == "__main__":
+ googletest.main()
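To run the new test from a configured TensorFlow source checkout (the Bazel target follows from the `cuda_py_test` rule added above; a GPU variant is generated as well):

```
bazel test //tensorflow/python/debug:debug_gradients_test
```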