From 31f1375c01782f53cf12ce2aa4efa584483c14ea Mon Sep 17 00:00:00 2001
From: Shanqing Cai
Date: Thu, 13 Jul 2017 07:40:00 -0700
Subject: Add GradientsDebugger to tfdbg to allow retrieval of gradient tensors
 created by TensorFlow's automatic differentiation algorithm (i.e.,
 tf.gradients and optimizer code that uses it).

PiperOrigin-RevId: 161805516
---
 tensorflow/contrib/cmake/tf_core_ops.cmake     |  12 +
 tensorflow/contrib/cmake/tf_python.cmake       |   2 +
 tensorflow/core/ops/debug_ops.cc               |   3 +
 tensorflow/python/debug/BUILD                  |  40 ++
 tensorflow/python/debug/__init__.py            |   5 +
 tensorflow/python/debug/lib/debug_gradients.py | 417 +++++++++++++++++++++
 .../python/debug/lib/debug_gradients_test.py   | 378 +++++++++++++++++++
 7 files changed, 857 insertions(+)
 create mode 100644 tensorflow/python/debug/lib/debug_gradients.py
 create mode 100644 tensorflow/python/debug/lib/debug_gradients_test.py

diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index 21d94da6cc..b350b822dd 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -129,3 +129,15 @@ list(REMOVE_ITEM tf_core_ops_srcs ${tf_core_ops_exclude_srcs})
 add_library(tf_core_ops OBJECT ${tf_core_ops_srcs})
 
 add_dependencies(tf_core_ops tf_core_cpu)
+
+########################################################
+# tf_debug_ops library
+########################################################
+
+file(GLOB tf_debug_ops_srcs
+    "${tensorflow_source_dir}/tensorflow/core/ops/debug_ops.cc"
+)
+
+add_library(tf_debug_ops OBJECT ${tf_debug_ops_srcs})
+
+add_dependencies(tf_debug_ops tf_core_framework)
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index beb630bdfc..0024efab38 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -692,6 +692,8 @@ GENERATE_PYTHON_OP_LIB("contrib_bigquery_reader_ops"
     DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cloud/python/ops/gen_bigquery_reader_ops.py)
 GENERATE_PYTHON_OP_LIB("stateless_random_ops"
     DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py)
+GENERATE_PYTHON_OP_LIB("debug_ops"
+    DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/debug/ops/gen_debug_ops.py)
 
 add_custom_target(tf_python_ops SOURCES ${tf_python_ops_generated_files} ${PYTHON_PROTO_GENFILES})
 add_dependencies(tf_python_ops tf_python_op_gen_main)
diff --git a/tensorflow/core/ops/debug_ops.cc b/tensorflow/core/ops/debug_ops.cc
index f7a96b58da..bd7f7c2c01 100644
--- a/tensorflow/core/ops/debug_ops.cc
+++ b/tensorflow/core/ops/debug_ops.cc
@@ -14,7 +14,9 @@ limitations under the License.
 ==============================================================================*/
 
 // This file registers all TensorFlow Debugger (tfdbg) ops.
+#include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
 
 namespace tensorflow {
 
@@ -88,6 +90,7 @@ REGISTER_OP("DebugIdentity")
     .Attr("debug_urls: list(string) = []")
     .Attr("gated_grpc: bool = false")
     .SetAllowsUninitializedInput()
+    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Debug Identity Op.
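As a quick orientation before the Python diffs below: a minimal usage sketch of
the API this patch adds, pieced together from the docstrings and tests that
follow. It is not part of the diff itself; the `tf_debug` alias is the
conventional tfdbg import, and the variable names are illustrative only.

```python
import tensorflow as tf
from tensorflow.python import debug as tf_debug

u = tf.Variable(2.0, name="u")
v = tf.Variable(3.0, name="v")
w = tf.multiply(u, v, name="w")
y = tf.add(w, -1.0, name="y")

# Watch the gradient with respect to w after the forward graph is built.
grad_debugger = tf_debug.GradientsDebugger()
with grad_debugger.watch_gradients_by_tensor_names(y.graph, r"w:0$"):
  # minimize() calls tf.gradients under the hood; the gradient flowing
  # through the inserted DebugIdentity op is registered as dy/dw.
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(y)

w_grad = grad_debugger.gradient_tensor("w:0")

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(w_grad))  # dy/dw of y = w - 1 is 1.0.
```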
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 86d7b5f770..d7ca36b12b 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -23,6 +23,7 @@ exports_files(["LICENSE"])
 load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 load("//tensorflow:tensorflow.bzl", "py_test")
 load("//tensorflow:tensorflow.bzl", "if_not_windows")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
 
 py_library(
     name = "debug_py",
@@ -31,6 +32,7 @@ py_library(
     visibility = ["//visibility:public"],
     deps = [
         ":debug_data",
+        ":debug_gradients",
         ":debug_utils",
         ":hooks",
         ":local_cli_wrapper",
@@ -60,6 +62,24 @@ py_library(
     ],
 )
 
+tf_gen_op_wrapper_py(
+    name = "debug_ops",
+    deps = ["//tensorflow/core:debug_ops_op_lib"],
+)
+
+py_library(
+    name = "debug_gradients",
+    srcs = ["lib/debug_gradients.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":debug_data",
+        ":debug_ops",
+        "//tensorflow/python:framework",
+        "//tensorflow/python:platform",
+        "@six_archive//:six",
+    ],
+)
+
 py_library(
     name = "debug_utils",
     srcs = ["lib/debug_utils.py"],
@@ -381,6 +401,26 @@ py_test(
     ],
 )
 
+cuda_py_test(
+    name = "debug_gradients_test",
+    size = "small",
+    srcs = [
+        "lib/debug_gradients_test.py",
+    ],
+    additional_deps = [
+        ":debug_data",
+        ":debug_gradients",
+        ":debug_utils",
+        "//tensorflow/python:client",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:gradients",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variables",
+    ],
+)
+
 py_test(
     name = "debug_utils_test",
     size = "small",
diff --git a/tensorflow/python/debug/__init__.py b/tensorflow/python/debug/__init__.py
index 147e42d878..e20849bc1c 100644
--- a/tensorflow/python/debug/__init__.py
+++ b/tensorflow/python/debug/__init__.py
@@ -31,6 +31,9 @@ See the @{$python/tfdbg} guide.
 @@LocalCLIDebugHook
 @@LocalCLIDebugWrapperSession
 @@WatchOptions
+
+@@GradientsDebugger
+@@clear_gradient_debuggers
 """
 
 from __future__ import absolute_import
@@ -44,6 +47,8 @@ from tensorflow.python.debug.lib.debug_data import has_inf_or_nan
 from tensorflow.python.debug.lib.debug_data import load_tensor_from_event
 from tensorflow.python.debug.lib.debug_data import load_tensor_from_event_file
 
+from tensorflow.python.debug.lib.debug_gradients import GradientsDebugger
+
 from tensorflow.python.debug.lib.debug_utils import add_debug_tensor_watch
 from tensorflow.python.debug.lib.debug_utils import watch_graph
 from tensorflow.python.debug.lib.debug_utils import watch_graph_with_blacklists
diff --git a/tensorflow/python/debug/lib/debug_gradients.py b/tensorflow/python/debug/lib/debug_gradients.py
new file mode 100644
index 0000000000..e5159cb941
--- /dev/null
+++ b/tensorflow/python/debug/lib/debug_gradients.py
@@ -0,0 +1,417 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorFlow Debugger: Tools for debugging gradients."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+import uuid
+
+import six
+
+from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.ops import gen_debug_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
+
+_GRADIENT_DEBUG_TAG = "gradient_debug_"
+
+_gradient_debuggers = {}
+
+
+def _tensor_to_grad_debug_op_name(tensor, grad_debugger_uuid):
+  op_name, slot = debug_data.parse_node_or_tensor_name(tensor.name)
+  return "%s_%d/%s%s" % (op_name, slot, _GRADIENT_DEBUG_TAG, grad_debugger_uuid)
+
+
+def _parse_grad_debug_op_name(op_name):
+  """Parse the name of a debug gradient op.
+
+  Args:
+    op_name: the name of the debug gradient op.
+
+  Returns:
+    1) The UUID of the GradientsDebugger that created the debug gradient op.
+    2) Name of the original tensor whose gradient is debugged by the debug
+       gradient op.
+  """
+  name_items = op_name.split("/")
+  assert len(name_items) > 1
+  assert name_items[-1].startswith(_GRADIENT_DEBUG_TAG)
+
+  grad_debugger_uuid = name_items[-1][len(_GRADIENT_DEBUG_TAG):]
+  if "_" in grad_debugger_uuid:
+    grad_debugger_uuid = grad_debugger_uuid[:grad_debugger_uuid.index("_")]
+  orig_tensor_slot = int(name_items[-2][name_items[-2].rfind("_") + 1:])
+  orig_base_op_name = name_items[-2][:name_items[-2].rfind("_")]
+  orig_tensor_name = ("/".join(name_items[:-2] + [orig_base_op_name]) +
+                      ":%d" % orig_tensor_slot)
+
+  return grad_debugger_uuid, orig_tensor_name
+
+
+class GradientsDebugger(object):
+  """Gradients Debugger.
+
+  Allows retrieval of gradient tensors created by TensorFlow's automatic
+  differentiation algorithm, i.e., @{tf.gradients} and optimizer classes that
+  use it.
+  """
+  # TODO(cais): Add examples code in the doc string?
+
+  def __init__(self, y_tensor=None):
+    """Constructor of GradientsDebugger.
+
+    Args:
+      y_tensor: optional: the `tf.Tensor` to be differentiated, i.e., the tensor
+        on the numerator of the differentiation.
+    """
+
+    self._uuid = uuid.uuid4().hex
+    _gradient_debuggers[self._uuid] = self
+
+    # A dict mapping x-tensor names to gradient tensor. x-tensor refers to the
+    # independent tf.Tensor, i.e., the tensor on the denominator of the
+    # differentiation.
+    self._gradient_tensors = {}
+    self._y_tensor = y_tensor
+
+    self._graph = None
+    if y_tensor:
+      self._graph = y_tensor.graph
+
+    self._is_active_context = False
+
+  @property
+  def y_tensor(self):
+    return self._y_tensor
+
+  @property
+  def graph(self):
+    return self._graph
+
+  def __enter__(self):
+    self._is_active_context = True
+
+  def __exit__(self, unused_type, unused_value, unused_traceback):
+    self._is_active_context = False
+
+  def identify_gradient(self, input_tensor):
+    """Create a debug identity tensor that registers and forwards gradients.
+
+    The side effect of this method is that when gradient tensor(s) are created
+    with respect to any paths that include the `input_tensor`, the gradient
+    tensor(s) with respect to `input_tensor` will be registered with this
+    `GradientsDebugger` instance and can later be retrieved with the
+    methods `gradient_tensor` and `gradient_tensors`.
+
+    Example:
+
+    ```python
+    x = tf.Variable(1.0)
+    y = tf.add(x, x)
+
+    grad_debugger = tf_debug.GradientsDebugger()
+    debug_y = grad_debugger.identify_gradient(y)
+    z = tf.square(debug_y)
+
+    # Create a train op under the grad_debugger context.
+    with grad_debugger:
+      train_op = tf.train.GradientDescentOptimizer(0.1).minimize(z)
+
+    # Now we can reflect through grad_debugger to get the gradient tensor
+    # with respect to y.
+    y_grad = grad_debugger.gradient_tensor(y)
+    ```
+
+    Args:
+      input_tensor: the input `tf.Tensor` object whose related gradient tensors
+        are to be registered with this `GradientsDebugger` instance when they
+        are created, e.g., during @{tf.gradients} calls or the construction
+        of an optimization (training) op that uses @{tf.gradients}.
+
+    Returns:
+      A forwarded identity of `input_tensor`, as a `tf.Tensor`.
+
+    Raises:
+      ValueError: If an op with a name that duplicates that of the
+        gradient-debugging op already exists in the graph (highly unlikely).
+    """
+    # TODO(cais): Allow overriding gradient.
+    # TODO(cais): Implement value_stack.
+    grad_debug_op_name = _tensor_to_grad_debug_op_name(input_tensor, self._uuid)
+    debug_identity = gen_debug_ops.debug_identity(
+        input_tensor,
+        tensor_name=input_tensor.name,
+        debug_urls=[],
+        name=grad_debug_op_name)
+    if debug_identity.op.name != grad_debug_op_name:
+      raise ValueError(
+          "The graph already contains an op named %s" % grad_debug_op_name)
+    return debug_identity
+
+  def watch_gradients_by_tensors(self, graph, tensors):
+    """Watch gradient tensors by x-tensor(s).
+
+    The side effect of this method is that when gradient tensor(s) are created
+    with respect to any paths that include the given `tensors`, the gradient
+    tensor(s) with respect to those tensors will be registered with this
+    `GradientsDebugger` instance and can later be retrieved with the
+    methods `gradient_tensor` and `gradient_tensors`.
+
+    Unlike the method `identify_gradient`, this method is used to retrieve
+    gradient tensors after the construction of the forward subgraph has
+    completed (but before the construction of the backward subgraph).
+
+    This method is the same as `watch_gradients_by_tensor_names` except that
+    the tensors are specified by the Python `tf.Tensor` or `tf.Variable`
+    objects, instead of by name patterns.
+
+    Example:
+
+    ```python
+    x = tf.Variable(1.0)
+    y = tf.add(x, x, name="y")
+    z = tf.square(y)
+
+    # Create a train op under the grad_debugger context.
+    grad_debugger = tf_debug.GradientsDebugger()
+    with grad_debugger.watch_gradients_by_tensors(y.graph, y):
+      train_op = tf.train.GradientDescentOptimizer(0.1).minimize(z)
+
+    # Now we can reflect through grad_debugger to get the gradient tensor
+    # with respect to y.
+    y_grad = grad_debugger.gradient_tensor(y)
+    # or
+    y_grad = grad_debugger.gradient_tensor("y:0")
+    ```
+
+    Args:
+      graph: the `tf.Graph` to watch the gradients on.
+      tensors: a `tf.Tensor` or `tf.Variable` object, or a list of such objects.
+
+    Returns:
+      The GradientsDebugger instance itself.
+    """
+
+    if not isinstance(tensors, list):
+      tensors = [tensors]
+
+    tensor_name_regex = []
+    for tensor in tensors:
+      tensor_name_regex.append(re.escape(tensor.name) + "$")
+    tensor_name_regex = "(" + "|".join(tensor_name_regex) + ")"
+    return self.watch_gradients_by_tensor_names(graph, tensor_name_regex)
+
+  def watch_gradients_by_tensor_names(self, graph, tensor_name_regex):
+    """Watch gradient tensors by name(s) of the x-tensor(s).
+
+    The side effect of this method is that when gradient tensor(s) are created
+    with respect to the x-tensors, the gradient tensor(s) will be registered
+    with this `GradientsDebugger` instance and can later be retrieved.
+
+    Unlike the `identify_gradient` method, this method is used after the
+    construction of the forward graph has completed. Unlike the
+    `watch_gradients_by_tensors` method, this method does not use handles to
+    the tensors of interest; it uses their names.
+
+    This method is the same as `watch_gradients_by_tensors` except that the
+    x-tensors are specified by name patterns, instead of `tf.Tensor` or
+    `tf.Variable` objects.
+
+    Example:
+
+    ```python
+    x = tf.Variable(1.0, name="x")
+    y = tf.add(x, x, name="y")
+    z = tf.square(y)
+
+    # Create a train op under the grad_debugger context.
+    grad_debugger = tf_debug.GradientsDebugger()
+    with grad_debugger.watch_gradients_by_tensor_names(y.graph, r"(x|y):0$"):
+      train_op = tf.train.GradientDescentOptimizer(0.1).minimize(z)
+
+    # Now we can reflect through grad_debugger to get the gradient tensor
+    # with respect to x and y.
+    x_grad = grad_debugger.gradient_tensor("x:0")
+    y_grad = grad_debugger.gradient_tensor("y:0")
+    ```
+
+    Args:
+      graph: the `tf.Graph` to watch the gradients on.
+      tensor_name_regex: the regular-expression pattern of the name(s) of the
+        x-tensor(s) to watch. x-tensor refers to the tensors on the denominator
+        of the differentiation.
+
+    Returns:
+      The GradientsDebugger instance itself.
+    """
+    tensor_name_pattern = re.compile(tensor_name_regex)
+
+    # pylint: disable=protected-access
+    with graph.as_default():
+      for op in graph.get_operations():
+        for output in op.outputs:
+          if tensor_name_pattern.match(output.name):
+            debug_op = self.identify_gradient(output)
+
+            for consumer in output.consumers():
+              if consumer == debug_op.op:
+                continue
+
+              # Locate the slot index of the original input.
+              input_slots = []
+              for i, consumer_input in enumerate(consumer._inputs):
+                if consumer_input == output:
+                  input_slots.append(i)
+
+              for slot in input_slots:
+                consumer._inputs[slot] = debug_op
+                debug_op._consumers.append(consumer)
+
+            del output._consumers[:]
+            output._consumers.append(debug_op.op)
+    # pylint: enable=protected-access
+
+    return self
+
+  def _check_same_graph(self, tensor):
+    if self._graph is None:
+      self._graph = tensor.graph
+    elif self._graph != tensor.graph:
+      raise ValueError(
+          "The graph of the value (%s) is not the same as the graph %s" %
+          (tensor.graph, self._graph))
+
+  def register_gradient_tensor(self,
+                               x_tensor_name,
+                               gradient_tensor):
+    """Register the gradient tensor for an x-tensor.
+
+    Args:
+      x_tensor_name: (`str`) the name of the independent `tf.Tensor`, i.e.,
+        the tensor on the denominator of the differentiation.
+      gradient_tensor: the gradient `tf.Tensor`.
+    """
+    if len(_gradient_debuggers) == 1 or self._is_active_context:
+      self._check_same_graph(gradient_tensor)
+      self._gradient_tensors[x_tensor_name] = gradient_tensor
+
+  def gradient_tensor(self, x_tensor):
+    """Get the gradient tensor of an x-tensor.
+
+    Args:
+      x_tensor: (`tf.Tensor`, `tf.Variable` or `str`) The x-tensor object or its
+        name. x-tensor refers to the independent `tf.Tensor`, i.e., the tensor
+        on the denominator of the differentiation.
+
+    Returns:
+      If found, the gradient tensor.
+
+    Raises:
+      TypeError: If `x_tensor` is not a `tf.Tensor`, `tf.Variable` or `str`.
+      LookupError: If the `x_tensor` has not been registered with a gradient
+        tensor.
+ """ + x_tensor_name = self._get_tensor_name(x_tensor) + if x_tensor_name not in self._gradient_tensors: + raise LookupError( + "This GradientsDebugger has not received any gradient tensor for " + "x-tensor %s" % x_tensor_name) + return self._gradient_tensors[x_tensor_name] + + def gradient_tensors(self): + """Get the gradient tensors that this object is aware of. + + Returns: + A dict mapping x-tensor names to gradient tensor objects. x-tensor refers + to the tensors on the denominator of the differentation. + """ + return self._gradient_tensors + + def _get_tensor_name(self, tensor): + if isinstance(tensor, (ops.Tensor, variables.Variable)): + return tensor.name + elif isinstance(tensor, six.string_types): + return tensor + else: + raise TypeError( + "x_tensor must be a str or tf.Tensor or tf.Variable, " + "but instead has type %s" % type(tensor)) + + +def clear_gradient_debuggers(): + """Clear all globally registered gradient debuggers.""" + _gradient_debuggers.clear() + + +@ops.RegisterGradient("DebugIdentity") +def _identify_gradient_grad(op, dy): + """Gradient function for the DebugIdentity op.""" + # TODO(cais): Allow overriding gradient. + grad_debugger_uuid, orig_tensor_name = _parse_grad_debug_op_name(op.name) + grad_debugger = _gradient_debuggers[grad_debugger_uuid] + grad_debugger.register_gradient_tensor(orig_tensor_name, dy) + return dy + + +def gradient_values_from_dump(grad_debugger, x_tensor, dump): + """Find gradient values from a `DebugDumpDir` object. + + Args: + grad_debugger: the `tf_debug.GradientsDebugger` instance to be used. + x_tensor: (`tf.Tensor`, `tf.Variable` or `str`) The x-tensor object or its + name. x-tensor refers to the independent `tf.Tensor`, i.e., the tensor + on the denominator of the differentiation. + dump: A `tfdbg.DebugDumpDir` object. + + Returns: + If this `GradientsDebugger` instance has the gradient tensor of `x_tensor` + registered: a list of `numpy.ndarray` representing the value of the + gradient tensor from `dump`. The list could be empty, if the gradient + tensor is not executed in the `tf.Session.run()` call that generated + the `dump`. The list could also contain multiple values of the gradient + tensor, e.g., if gradient tensor is computed repeatedly in a + `tf.while_loop` during the run that generated the `dump`. + + Raises: + LookupError: If this `GradientsDebugger` instance does not have the + gradient tensor of `x_tensor` registered. + ValueError: If this `GradientsDebugger` has a `tf.Graph` object that + does not match the `tf.Graph` object of the `dump`. + TypeError: If `x_tensor` is not a `tf.Tensor`, `tf.Variable` or `str`. + """ + # TODO(cais): Use this method in LocalCLIDebugWrapperSession to present the + # gradient tensors to the TFDBG CLI. + + # If possible, verify that the Python graph of the dump and that of this + # GradientsDebugger match. + if (dump.python_graph and grad_debugger.graph and + dump.python_graph != grad_debugger.graph): + raise ValueError( + "This GradientsDebugger instance has a graph (%s) that differs from " + "the graph of the DebugDumpDir object (%s)." 
+        (grad_debugger.graph, dump.python_graph))
+
+  gradient_tensor = grad_debugger.gradient_tensor(x_tensor)
+  node_name, output_slot = debug_data.parse_node_or_tensor_name(
+      gradient_tensor.name)
+
+  try:
+    return dump.get_tensors(node_name, output_slot, "DebugIdentity")
+  except debug_data.WatchKeyDoesNotExistInDebugDumpDirError:
+    return []
diff --git a/tensorflow/python/debug/lib/debug_gradients_test.py b/tensorflow/python/debug/lib/debug_gradients_test.py
new file mode 100644
index 0000000000..966578320e
--- /dev/null
+++ b/tensorflow/python/debug/lib/debug_gradients_test.py
@@ -0,0 +1,378 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit tests for debug_gradients module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import shutil
+import tempfile
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.debug.lib import debug_data
+from tensorflow.python.debug.lib import debug_gradients
+from tensorflow.python.debug.lib import debug_utils
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+from tensorflow.python.training import gradient_descent
+
+
+class IdentifyGradientTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self.sess = session.Session()
+    with self.sess:
+      self.u = variables.Variable(2.0, name="u")
+      self.v = variables.Variable(3.0, name="v")
+      self.w = math_ops.multiply(self.u.value(), self.v.value(), name="w")
+
+  def tearDown(self):
+    ops.reset_default_graph()
+    debug_gradients.clear_gradient_debuggers()
+
+  def testIdentifyGradientGivesCorrectTensorObjectWithoutContextManager(self):
+    grad_debugger = debug_gradients.GradientsDebugger()
+    id_grad_w = grad_debugger.identify_gradient(self.w)
+    y = math_ops.add(id_grad_w, -1.0, name="y")
+
+    grads = gradients_impl.gradients(y, [self.u, self.v])
+    self.assertEqual(2, len(grads))
+    u_grad = grads[0]
+    v_grad = grads[1]
+
+    self.sess.run(variables.global_variables_initializer())
+    self.assertAllClose(5.0, self.sess.run(y))
+    self.assertAllClose(3.0, self.sess.run(u_grad))
+    self.assertAllClose(2.0, self.sess.run(v_grad))
+
+    # Fetch the gradient tensor with the x-tensor object.
+    w_grad = grad_debugger.gradient_tensor(self.w)
+    self.assertIsInstance(w_grad, ops.Tensor)
+    self.assertAllClose(1.0, self.sess.run(w_grad))
+
+    # Fetch the gradient tensor with the x-tensor's name.
+    w_grad = grad_debugger.gradient_tensor(self.w.name)
+    self.assertIsInstance(w_grad, ops.Tensor)
+    self.assertAllClose(1.0, self.sess.run(w_grad))
+
+    # Fetch the gradient tensor with the x-tensor name.
+    w_grad = grad_debugger.gradient_tensor(self.w.name)
+    self.assertIsInstance(w_grad, ops.Tensor)
+    self.assertAllClose(1.0, self.sess.run(w_grad))
+
+  def testIdentifyGradientGivesCorrectTensorObjectWithTfGradients(self):
+    grad_debugger = debug_gradients.GradientsDebugger()
+    id_grad_w = grad_debugger.identify_gradient(self.w)
+    y = math_ops.add(id_grad_w, -1.0, name="y")
+
+    with grad_debugger:
+      grads = gradients_impl.gradients(y, [self.u, self.v])
+    self.assertEqual(2, len(grads))
+    u_grad = grads[0]
+    v_grad = grads[1]
+
+    self.sess.run(variables.global_variables_initializer())
+    self.assertAllClose(5.0, self.sess.run(y))
+    self.assertAllClose(3.0, self.sess.run(u_grad))
+    self.assertAllClose(2.0, self.sess.run(v_grad))
+
+    # Fetch the gradient tensor with the x-tensor object.
+    w_grad = grad_debugger.gradient_tensor(self.w)
+    self.assertIsInstance(w_grad, ops.Tensor)
+    self.assertAllClose(1.0, self.sess.run(w_grad))
+
+    # Fetch the gradient tensor with the x-tensor's name.
+    w_grad = grad_debugger.gradient_tensor(self.w.name)
+    self.assertIsInstance(w_grad, ops.Tensor)
+    self.assertAllClose(1.0, self.sess.run(w_grad))
+
+    # Fetch the gradient tensor with the x-tensor name.
+    w_grad = grad_debugger.gradient_tensor(self.w.name)
+    self.assertIsInstance(w_grad, ops.Tensor)
+    self.assertAllClose(1.0, self.sess.run(w_grad))
+
+  def testCallingIdentifyGradientTwiceWithTheSameGradientsDebuggerErrors(self):
+    grad_debugger = debug_gradients.GradientsDebugger()
+    grad_debugger.identify_gradient(self.w)
+    with self.assertRaisesRegexp(
+        ValueError, "The graph already contains an op named .*"):
+      grad_debugger.identify_gradient(self.w)
+
+  def testIdentifyGradientWorksOnMultipleLosses(self):
+    grad_debugger_1 = debug_gradients.GradientsDebugger()
+    grad_debugger_2 = debug_gradients.GradientsDebugger()
+
+    y = math_ops.add(self.w, -1.0, name="y")
+    debug_y = grad_debugger_1.identify_gradient(y)
+    z1 = math_ops.square(debug_y, name="z1")
+
+    debug_y = grad_debugger_2.identify_gradient(y)
+    z2 = math_ops.sqrt(debug_y, name="z2")
+
+    with grad_debugger_1:
+      gradient_descent.GradientDescentOptimizer(0.1).minimize(z1)
+    with grad_debugger_2:
+      gradient_descent.GradientDescentOptimizer(0.1).minimize(z2)
+
+    dz1_dy = grad_debugger_1.gradient_tensor(y)
+    dz2_dy = grad_debugger_2.gradient_tensor(y)
+    self.assertIsInstance(dz1_dy, ops.Tensor)
+    self.assertIsInstance(dz2_dy, ops.Tensor)
+    self.assertIsNot(dz1_dy, dz2_dy)
+
+    self.sess.run(variables.global_variables_initializer())
+    self.assertAllClose(5.0 ** 2, self.sess.run(z1))
+    self.assertAllClose(5.0 ** 0.5, self.sess.run(z2))
+    self.assertAllClose(2.0 * 5.0, self.sess.run(dz1_dy))
+    self.assertAllClose(0.5 * (5.0 ** -0.5), self.sess.run(dz2_dy))
+
+  def testIdentifyGradientRaisesLookupErrorForUnknownXTensor(self):
+    grad_debugger_1 = debug_gradients.GradientsDebugger()
+    grad_debugger_2 = debug_gradients.GradientsDebugger()
+    id_grad_w = grad_debugger_1.identify_gradient(self.w)
+    y = math_ops.add(id_grad_w, -1.0, name="y")
+
+    # There are >1 gradient debuggers registered, and grad_debugger is not used
+    # as a context manager here, so the gradient w.r.t. self.w will not be
+    # registered.
+    gradients_impl.gradients(y, [self.u, self.v])
+
+    with self.assertRaisesRegexp(
+        LookupError,
+        r"This GradientsDebugger has not received any gradient tensor for "):
+      grad_debugger_1.gradient_tensor(self.w)
+    with self.assertRaisesRegexp(
+        LookupError,
+        r"This GradientsDebugger has not received any gradient tensor for "):
+      grad_debugger_2.gradient_tensor(self.w)
+
+  def testIdentifyGradientRaisesTypeErrorForNonTensorOrTensorNameInput(self):
+    grad_debugger = debug_gradients.GradientsDebugger()
+    with self.assertRaisesRegexp(
+        TypeError,
+        r"x_tensor must be a str or tf\.Tensor or tf\.Variable, but instead "
+        r"has type .*Operation.*"):
+      grad_debugger.gradient_tensor(variables.global_variables_initializer())
+
+  def testIdentifyGradientTensorWorksWithGradientDescentOptimizer(self):
+    grad_debugger = debug_gradients.GradientsDebugger()
+    id_grad_w = grad_debugger.identify_gradient(self.w)
+    y = math_ops.add(id_grad_w, -1.0, name="y")
+
+    with grad_debugger:
+      gradient_descent.GradientDescentOptimizer(0.1).minimize(y)
+
+    self.sess.run(variables.global_variables_initializer())
+
+    # Fetch the gradient tensor with the x-tensor object.
+    w_grad = grad_debugger.gradient_tensor(self.w)
+    self.assertIsInstance(w_grad, ops.Tensor)
+    self.assertAllClose(1.0, self.sess.run(w_grad))
+
+  def testWatchGradientsByXTensorNamesWorks(self):
+    y = math_ops.add(self.w, -1.0, name="y")
+
+    # The construction of the forward graph has completed.
+    # But we can still get the gradient tensors by using
+    # watch_gradients_by_tensor_names().
+    grad_debugger = debug_gradients.GradientsDebugger()
+    with grad_debugger.watch_gradients_by_tensor_names(self.sess.graph, "w:0$"):
+      grads = gradients_impl.gradients(y, [self.u, self.v])
+    self.assertEqual(2, len(grads))
+    u_grad = grads[0]
+    v_grad = grads[1]
+
+    self.sess.run(variables.global_variables_initializer())
+    self.assertAllClose(5.0, self.sess.run(y))
+    self.assertAllClose(3.0, self.sess.run(u_grad))
+    self.assertAllClose(2.0, self.sess.run(v_grad))
+
+    w_grad = grad_debugger.gradient_tensor(self.w)
+    self.assertIsInstance(w_grad, ops.Tensor)
+    self.assertAllClose(1.0, self.sess.run(w_grad))
+
+    w_grad = grad_debugger.gradient_tensor("w:0")
+    self.assertIsInstance(w_grad, ops.Tensor)
+    self.assertAllClose(1.0, self.sess.run(w_grad))
+
+  def testWatchGradientsByXTensorNamesWorksWithoutContextManager(self):
+    y = math_ops.add(self.w, -1.0, name="y")
+
+    # The construction of the forward graph has completed.
+    # But we can still get the gradient tensors by using
+    # watch_gradients_by_tensor_names().
+    grad_debugger = debug_gradients.GradientsDebugger()
+    grad_debugger.watch_gradients_by_tensor_names(self.sess.graph, "w:0$")
+    grads = gradients_impl.gradients(y, [self.u, self.v])
+    self.assertEqual(2, len(grads))
+    u_grad = grads[0]
+    v_grad = grads[1]
+
+    self.sess.run(variables.global_variables_initializer())
+    self.assertAllClose(5.0, self.sess.run(y))
+    self.assertAllClose(3.0, self.sess.run(u_grad))
+    self.assertAllClose(2.0, self.sess.run(v_grad))
+
+    w_grad = grad_debugger.gradient_tensor(self.w)
+    self.assertIsInstance(w_grad, ops.Tensor)
+    self.assertAllClose(1.0, self.sess.run(w_grad))
+
+    w_grad = grad_debugger.gradient_tensor("w:0")
+    self.assertIsInstance(w_grad, ops.Tensor)
+    self.assertAllClose(1.0, self.sess.run(w_grad))
+
+  def testWatchGradientsWorksOnRefTensor(self):
+    y = math_ops.add(self.w, -1.0, name="y")
+
+    grad_debugger = debug_gradients.GradientsDebugger()
+    with grad_debugger.watch_gradients_by_tensor_names(self.sess.graph, "u:0$"):
+      grads = gradients_impl.gradients(y, [self.u, self.v])
+    self.assertEqual(2, len(grads))
+    u_grad = grads[0]
+    v_grad = grads[1]
+
+    self.assertIs(u_grad, grad_debugger.gradient_tensor("u:0"))
+
+    self.sess.run(variables.global_variables_initializer())
+    self.assertAllClose(3.0, self.sess.run(u_grad))
+    self.assertAllClose(2.0, self.sess.run(v_grad))
+    self.assertAllClose(
+        3.0, self.sess.run(grad_debugger.gradient_tensor("u:0")))
+
+  def testWatchGradientsWorksOnMultipleTensors(self):
+    y = math_ops.add(self.w, -1.0, name="y")
+
+    grad_debugger = debug_gradients.GradientsDebugger()
+    with grad_debugger.watch_gradients_by_tensor_names(self.sess.graph,
+                                                       "(u|w):0$"):
+      grads = gradients_impl.gradients(y, [self.u, self.v])
+    self.assertEqual(2, len(grads))
+    u_grad = grads[0]
+
+    self.assertEqual(2, len(grad_debugger.gradient_tensors()))
+    self.assertIs(u_grad, grad_debugger.gradient_tensor("u:0"))
+    self.assertIsInstance(grad_debugger.gradient_tensor("w:0"), ops.Tensor)
+
+    self.sess.run(variables.global_variables_initializer())
+    self.assertAllClose(
+        1.0, self.sess.run(grad_debugger.gradient_tensor("w:0")))
+    self.assertAllClose(
+        3.0, self.sess.run(grad_debugger.gradient_tensor("u:0")))
+
+  def testWatchGradientsByXTensorsWorks(self):
+    y = math_ops.add(self.w, -1.0, name="foo/y")
+    z = math_ops.square(y, name="foo/z")
+
+    # The construction of the forward graph has completed.
+    # But we can still get the gradient tensors by using
+    # watch_gradients_by_tensors().
+    grad_debugger = debug_gradients.GradientsDebugger()
+    with grad_debugger.watch_gradients_by_tensors(
+        self.sess.graph, [self.w, self.u, y]):
+      gradient_descent.GradientDescentOptimizer(0.1).minimize(z)
+
+    self.assertEqual(3, len(grad_debugger.gradient_tensors()))
+    u_grad = grad_debugger.gradient_tensor(self.u)
+    w_grad = grad_debugger.gradient_tensor(self.w)
+    y_grad = grad_debugger.gradient_tensor(y)
+
+    self.sess.run(variables.global_variables_initializer())
+    self.assertAllClose(10.0, self.sess.run(y_grad))
+    self.assertAllClose(10.0, self.sess.run(w_grad))
+    self.assertAllClose(30.0, self.sess.run(u_grad))
+
+  def testWatchGradientsByTensorCanWorkOnMultipleLosses(self):
+    y = math_ops.add(self.w, -1.0, name="y")
+    z1 = math_ops.square(y, name="z1")
+    z2 = math_ops.sqrt(y, name="z2")
+
+    grad_debugger_1 = debug_gradients.GradientsDebugger()
+    with grad_debugger_1.watch_gradients_by_tensors(self.sess.graph, y):
+      gradient_descent.GradientDescentOptimizer(0.1).minimize(z1)
+
+    grad_debugger_2 = debug_gradients.GradientsDebugger()
+    with grad_debugger_2.watch_gradients_by_tensors(self.sess.graph, y):
+      gradient_descent.GradientDescentOptimizer(0.1).minimize(z2)
+
+    dz1_dy = grad_debugger_1.gradient_tensor(y)
+    dz2_dy = grad_debugger_2.gradient_tensor(y)
+    self.assertIsInstance(dz1_dy, ops.Tensor)
+    self.assertIsInstance(dz2_dy, ops.Tensor)
+    self.assertIsNot(dz1_dy, dz2_dy)
+
+    self.sess.run(variables.global_variables_initializer())
+    self.assertAllClose(5.0 ** 2, self.sess.run(z1))
+    self.assertAllClose(5.0 ** 0.5, self.sess.run(z2))
+    self.assertAllClose(2.0 * 5.0, self.sess.run(dz1_dy))
+    self.assertAllClose(0.5 * (5.0 ** -0.5), self.sess.run(dz2_dy))
+
+  def testGradientsValuesFromDumpWorks(self):
+    y = math_ops.add(self.w, -1.0, name="y")
+    z = math_ops.square(y, name="z")
+
+    grad_debugger = debug_gradients.GradientsDebugger()
+    with grad_debugger.watch_gradients_by_tensors(
+        self.sess.graph, [self.w, self.u, y]):
+      train_op = gradient_descent.GradientDescentOptimizer(0.1).minimize(z)
+
+    self.sess.run(variables.global_variables_initializer())
+
+    run_options = config_pb2.RunOptions(output_partition_graphs=True)
+    dump_dir = tempfile.mkdtemp()
+    debug_url = "file://" + dump_dir
+    debug_utils.watch_graph(
+        run_options,
+        self.sess.graph,
+        debug_urls=debug_url)
+    run_metadata = config_pb2.RunMetadata()
+    self.sess.run(train_op, options=run_options, run_metadata=run_metadata)
+
+    dump = debug_data.DebugDumpDir(
+        dump_dir, partition_graphs=run_metadata.partition_graphs)
+    dump.set_python_graph(self.sess.graph)
+
+    y_grad_values = debug_gradients.gradient_values_from_dump(
+        grad_debugger, y, dump)
+    self.assertEqual(1, len(y_grad_values))
+    self.assertAllClose(10.0, y_grad_values[0])
+
+    w_grad_values = debug_gradients.gradient_values_from_dump(
+        grad_debugger, self.w, dump)
+    self.assertEqual(1, len(w_grad_values))
+    self.assertAllClose(10.0, w_grad_values[0])
+
+    u_grad_values = debug_gradients.gradient_values_from_dump(
+        grad_debugger, self.u, dump)
+    self.assertEqual(1, len(u_grad_values))
+    self.assertAllClose(30.0, u_grad_values[0])
+
+    with self.assertRaisesRegexp(
+        LookupError,
+        r"This GradientsDebugger has not received any gradient tensor for "
+        r"x-tensor v:0"):
+      debug_gradients.gradient_values_from_dump(grad_debugger, self.v, dump)
+
+    # Cleanup.
+    shutil.rmtree(dump_dir)
+
+
+if __name__ == "__main__":
+  googletest.main()
-- 
cgit v1.2.3
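
For completeness, here is a sketch of the post-hoc workflow that
testGradientsValuesFromDumpWorks exercises above: run a train op with a tfdbg
file:// dump attached, then read the gradient values back through
gradient_values_from_dump. It is adapted from the test code, not taken
verbatim from the patch; values and names are illustrative only.

```python
import tempfile

import tensorflow as tf
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import debug as tf_debug
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.lib import debug_gradients
from tensorflow.python.debug.lib import debug_utils

w = tf.Variable(5.0, name="w")
y = tf.add(w, -1.0, name="y")
z = tf.square(y, name="z")

grad_debugger = tf_debug.GradientsDebugger()
with grad_debugger.watch_gradients_by_tensors(y.graph, y):
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(z)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())

  # Instrument this run so every watched tensor is dumped to disk.
  dump_dir = tempfile.mkdtemp()
  run_options = config_pb2.RunOptions(output_partition_graphs=True)
  debug_utils.watch_graph(
      run_options, sess.graph, debug_urls="file://" + dump_dir)
  run_metadata = config_pb2.RunMetadata()
  sess.run(train_op, options=run_options, run_metadata=run_metadata)

  dump = debug_data.DebugDumpDir(
      dump_dir, partition_graphs=run_metadata.partition_graphs)
  dump.set_python_graph(sess.graph)

  # One ndarray per execution of dz/dy in the run: 2 * y = 2 * 4 = 8.
  print(debug_gradients.gradient_values_from_dump(grad_debugger, y, dump))
```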