aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/debug
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-07 10:30:39 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-07 10:34:33 -0800
commit0e3b2b5805871f6d3dffff486a113b3c26c33b1f (patch)
tree4963e4e06bda8eab212171086c180556aa10bf78 /tensorflow/python/debug
parentf6010282b872b2c74eaf710a879b5b4a1756ae17 (diff)
Improve model_pruner:
* Actually remove nodes marked for removal if fetches are known. * Remove trivial nodes even in the presence of control inputs, except for Identity nodes when a) they are anchored on an Identity following a Switch node and removal would require anchoring a control identity on the Switch, or b) they have control inputs and feed a Merge node. * Remove nodes only when in_degree * out_degree <= in_degree + out_degree. Move input deduping utility function to utils.{h,cc}. PiperOrigin-RevId: 184858685
Diffstat (limited to 'tensorflow/python/debug')
-rw-r--r--tensorflow/python/debug/lib/debug_gradients_test.py42
1 file changed, 20 insertions, 22 deletions
diff --git a/tensorflow/python/debug/lib/debug_gradients_test.py b/tensorflow/python/debug/lib/debug_gradients_test.py
index c1e9869d97..01867fc69d 100644
--- a/tensorflow/python/debug/lib/debug_gradients_test.py
+++ b/tensorflow/python/debug/lib/debug_gradients_test.py
@@ -40,6 +40,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
def setUp(self):
rewriter_config = rewriter_config_pb2.RewriterConfig(
+ disable_model_pruning=True,
dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
config = config_pb2.ConfigProto(graph_options=graph_options)
@@ -117,8 +118,8 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
def testCallingIdentifyGradientTwiceWithTheSameGradientsDebuggerErrors(self):
grad_debugger = debug_gradients.GradientsDebugger()
grad_debugger.identify_gradient(self.w)
- with self.assertRaisesRegexp(
- ValueError, "The graph already contains an op named .*"):
+ with self.assertRaisesRegexp(ValueError,
+ "The graph already contains an op named .*"):
grad_debugger.identify_gradient(self.w)
def testIdentifyGradientWorksOnMultipleLosses(self):
@@ -144,10 +145,10 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
self.assertIsNot(dz1_dy, dz2_dy)
self.sess.run(variables.global_variables_initializer())
- self.assertAllClose(5.0 ** 2, self.sess.run(z1))
- self.assertAllClose(5.0 ** 0.5, self.sess.run(z2))
+ self.assertAllClose(5.0**2, self.sess.run(z1))
+ self.assertAllClose(5.0**0.5, self.sess.run(z2))
self.assertAllClose(2.0 * 5.0, self.sess.run(dz1_dy))
- self.assertAllClose(0.5 * (5.0 ** -0.5), self.sess.run(dz2_dy))
+ self.assertAllClose(0.5 * (5.0**-0.5), self.sess.run(dz2_dy))
def testIdentifyGradientRaisesLookupErrorForUnknownXTensor(self):
grad_debugger_1 = debug_gradients.GradientsDebugger()
@@ -259,8 +260,8 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
self.sess.run(variables.global_variables_initializer())
self.assertAllClose(3.0, self.sess.run(u_grad))
self.assertAllClose(2.0, self.sess.run(v_grad))
- self.assertAllClose(
- 3.0, self.sess.run(grad_debugger.gradient_tensor("u:0")))
+ self.assertAllClose(3.0, self.sess.run(
+ grad_debugger.gradient_tensor("u:0")))
def testWatchGradientsWorksOnMultipleTensors(self):
y = math_ops.add(self.w, -1.0, name="y")
@@ -277,10 +278,10 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
self.assertIsInstance(grad_debugger.gradient_tensor("w:0"), ops.Tensor)
self.sess.run(variables.global_variables_initializer())
- self.assertAllClose(
- 1.0, self.sess.run(grad_debugger.gradient_tensor("w:0")))
- self.assertAllClose(
- 3.0, self.sess.run(grad_debugger.gradient_tensor("u:0")))
+ self.assertAllClose(1.0, self.sess.run(
+ grad_debugger.gradient_tensor("w:0")))
+ self.assertAllClose(3.0, self.sess.run(
+ grad_debugger.gradient_tensor("u:0")))
def testWatchGradientsByXTensorsWorks(self):
y = math_ops.add(self.w, -1.0, name="foo/y")
@@ -290,8 +291,8 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
# But we can still get the gradient tensors by using
# watch_gradients_by_x_tensors().
grad_debugger = debug_gradients.GradientsDebugger()
- with grad_debugger.watch_gradients_by_tensors(
- self.sess.graph, [self.w, self.u, y]):
+ with grad_debugger.watch_gradients_by_tensors(self.sess.graph,
+ [self.w, self.u, y]):
gradient_descent.GradientDescentOptimizer(0.1).minimize(z)
self.assertEqual(3, len(grad_debugger.gradient_tensors()))
@@ -324,18 +325,18 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
self.assertIsNot(dz1_dy, dz2_dy)
self.sess.run(variables.global_variables_initializer())
- self.assertAllClose(5.0 ** 2, self.sess.run(z1))
- self.assertAllClose(5.0 ** 0.5, self.sess.run(z2))
+ self.assertAllClose(5.0**2, self.sess.run(z1))
+ self.assertAllClose(5.0**0.5, self.sess.run(z2))
self.assertAllClose(2.0 * 5.0, self.sess.run(dz1_dy))
- self.assertAllClose(0.5 * (5.0 ** -0.5), self.sess.run(dz2_dy))
+ self.assertAllClose(0.5 * (5.0**-0.5), self.sess.run(dz2_dy))
def testGradientsValuesFromDumpWorks(self):
y = math_ops.add(self.w, -1.0, name="y")
z = math_ops.square(y, name="z")
grad_debugger = debug_gradients.GradientsDebugger()
- with grad_debugger.watch_gradients_by_tensors(
- self.sess.graph, [self.w, self.u, y]):
+ with grad_debugger.watch_gradients_by_tensors(self.sess.graph,
+ [self.w, self.u, y]):
train_op = gradient_descent.GradientDescentOptimizer(0.1).minimize(z)
self.sess.run(variables.global_variables_initializer())
@@ -343,10 +344,7 @@ class IdentifyGradientTest(test_util.TensorFlowTestCase):
run_options = config_pb2.RunOptions(output_partition_graphs=True)
dump_dir = tempfile.mkdtemp()
debug_url = "file://" + dump_dir
- debug_utils.watch_graph(
- run_options,
- self.sess.graph,
- debug_urls=debug_url)
+ debug_utils.watch_graph(run_options, self.sess.graph, debug_urls=debug_url)
run_metadata = config_pb2.RunMetadata()
self.assertAllClose(2.0, self.sess.run(self.u))
self.sess.run(train_op, options=run_options, run_metadata=run_metadata)