diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2018-02-26 14:38:31 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-26 14:44:21 -0800 |
commit | 153e10a037c5e348834108ff46d9dccdf0cfb9a9 (patch) | |
tree | 411a202a638a10e08092b243cadf05a6293705be | |
parent | a80896d3b3a2358f324dc4cd429409ea9acc8a09 (diff) |
Enable de/serialization of nested control flow.
This is a follow-up to the previous commit
(https://github.com/tensorflow/tensorflow/commit/23851760b7b099214bdd4f1b88156d7ac2bdd2a2).
It adds the new proto schemas, enables the behavior for reading and
writing the new protos, and adds a test for de/serializing nested
while loops.
There's still a bug preventing deserializing conds, which will be addressed
in another change.
PiperOrigin-RevId: 187082713
-rw-r--r-- | tensorflow/core/protobuf/control_flow.proto | 17 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 54 | ||||
-rw-r--r-- | tensorflow/python/training/saver_test.py | 56 |
3 files changed, 88 insertions, 39 deletions
diff --git a/tensorflow/core/protobuf/control_flow.proto b/tensorflow/core/protobuf/control_flow.proto index 2c9476a08a..3c05b4f0e2 100644 --- a/tensorflow/core/protobuf/control_flow.proto +++ b/tensorflow/core/protobuf/control_flow.proto @@ -17,6 +17,15 @@ message ValuesDef { map<string, string> external_values = 2; } +// Container for any kind of control flow context. Any other control flow +// contexts that are added below should also be added here. +message ControlFlowContextDef { + oneof ctxt { + CondContextDef cond_ctxt = 1; + WhileContextDef while_ctxt = 2; + } +} + // Protocol buffer representing a CondContext object. message CondContextDef { // Name of the context. @@ -33,6 +42,9 @@ message CondContextDef { // Values and external values in control flow context. ValuesDef values_def = 5; + + // Contexts contained inside this context (e.g. nested conds). + repeated ControlFlowContextDef nested_contexts = 6; } // Protocol buffer representing a WhileContext object. @@ -70,5 +82,8 @@ message WhileContextDef { // Optional name of the maximum_iterations tensor. string maximum_iterations_name = 11; - // Next available id: 12. + // Contexts contained inside this context (e.g. nested whiles). + repeated ControlFlowContextDef nested_contexts = 12; + + // Next available id: 13. } diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 8d5ab72670..85944efbe8 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -1767,13 +1767,9 @@ class CondContext(ControlFlowContext): context_def.branch = self._branch context_def.values_def.MergeFrom(super(CondContext, self)._to_values_def( export_scope)) - # TODO(b/72868227): enable this once the corresponding control_flow.proto - # changes have been checked in (they aren't checked in and this is - # disabled for now to ensure forwards compatibility). - if False: # pylint: disable=using-constant-test - for nested in self._nested_contexts: - nested_def = context_def.nested_contexts.add() - nested.to_control_flow_context_def(nested_def) + for nested in self._nested_contexts: + nested_def = context_def.nested_contexts.add() + nested.to_control_flow_context_def(nested_def) return context_def else: @@ -1785,14 +1781,10 @@ class CondContext(ControlFlowContext): ret = CondContext(context_def=context_def, import_scope=import_scope) - # TODO(b/72868227): remove "if hasattr(...)" once the corresponding - # control_flow.proto changes have been checked in (they aren't checked in - # and this is here for now to ensure forwards compatibility). - if hasattr(context_def, "nested_contexts"): - ret.Enter() - for nested_def in context_def.nested_contexts: - from_control_flow_context_def(nested_def) - ret.Exit() + ret.Enter() + for nested_def in context_def.nested_contexts: + from_control_flow_context_def(nested_def) + ret.Exit() return ret def to_control_flow_context_def(self, context_def, export_scope=None): @@ -2110,10 +2102,7 @@ def cond(pred, # Only add non-nested conds to the collection. Any nested control flow will # be encapsulated in the root context. assert context_t.outer_context == context_f.outer_context - # TODO(b/72868227): remove "if True..." once the corresponding - # control_flow.proto changes have been checked in (they aren't checked in - # and this is disabled for now to ensure forwards compatibility). - if True or context_t.outer_context is None: + if context_t.outer_context is None: ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t) ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f) @@ -2336,13 +2325,9 @@ class WhileContext(ControlFlowContext): context_def.values_def.MergeFrom( super(WhileContext, self)._to_values_def( export_scope=export_scope)) - # TODO(b/72868227): remove "if True..." once the corresponding - # control_flow.proto changes have been checked in (they aren't checked in - # and this is disabled for now to ensure forwards compatibility). - if False: # pylint: disable=using-constant-test - for nested in self._nested_contexts: - nested_def = context_def.nested_contexts.add() - nested.to_control_flow_context_def(nested_def) + for nested in self._nested_contexts: + nested_def = context_def.nested_contexts.add() + nested.to_control_flow_context_def(nested_def) return context_def else: @@ -2364,14 +2349,10 @@ class WhileContext(ControlFlowContext): """ ret = WhileContext(context_def=context_def, import_scope=import_scope) - # TODO(b/72868227): remove "if hasattr(...)" once the corresponding - # control_flow.proto changes have been checked in (they aren't checked in - # and this is disabled for now to ensure forwards compatibility). - if hasattr(context_def, "nested_contexts"): - ret.Enter() - for nested_def in context_def.nested_contexts: - from_control_flow_context_def(nested_def, import_scope=import_scope) - ret.Exit() + ret.Enter() + for nested_def in context_def.nested_contexts: + from_control_flow_context_def(nested_def, import_scope=import_scope) + ret.Exit() return ret def GetWhileContext(self): @@ -3216,10 +3197,7 @@ def while_loop(cond, swap_memory=swap_memory) # Only add non-nested loops to the collection. Any nested control flow will # be encapsulated in the root context. - # TODO(b/72868227): enable condition once the corresponding - # control_flow.proto changes have been checked in (they aren't checked in - # and this is disabled for now to ensure forwards compatibility). - if True or loop_context.outer_context is None: + if loop_context.outer_context is None: ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context) result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants) if maximum_iterations is not None: diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index f00f98db00..b366ed30f3 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -53,6 +53,7 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import partitioned_variables @@ -2040,6 +2041,61 @@ class MetaGraphTest(test.TestCase): self._testGraphExtensionRestore(test_dir) self._testRestoreFromTrainGraphWithControlContext(test_dir) + def testNestedWhileLoops(self): + test_dir = self._get_test_dir("nested_whiles") + filename = os.path.join(test_dir, "metafile") + saver_ckpt = os.path.join(test_dir, "saver.ckpt") + + # Create two simple nested while loops. + with ops_lib.Graph().as_default(): + def body(i, x): + _, r = control_flow_ops.while_loop(lambda j, y: j < 3, + lambda j, y: (j + 1, y + x), + [0, 0]) + return i + 1, x + r + + var = variables.Variable(0) + var_name = var.name + + _, output = control_flow_ops.while_loop(lambda i, x: i < 5, body, + [0, var]) + output_name = output.name + + init_op = variables.global_variables_initializer() + + # Generate a MetaGraphDef containing the nested loops. + with session.Session() as sess: + sess.run(init_op) + sess.run(output) + saver = saver_module.Saver() + saver.save(sess, saver_ckpt) + saver.export_meta_graph(filename) + + # Build and run the gradients of the nested while loop. We use this below + # to verify that the gradients are correct with an imported MetaGraphDef. + grad = gradients_impl.gradients([output], [var]) + with session.Session() as sess: + sess.run(init_op) + expected_grad_value = sess.run(grad) + + # Restore the MetaGraphDef into a new Graph. + with ops_lib.Graph().as_default(): + with session.Session() as sess: + saver = saver_module.import_meta_graph(filename) + saver.restore(sess, saver_ckpt) + + # Make sure we can still build gradients and get the same result. + var = ops_lib.get_default_graph().get_tensor_by_name(var_name) + output = ops_lib.get_default_graph().get_tensor_by_name(output_name) + grad = gradients_impl.gradients([output], [var]) + + init_op = variables.global_variables_initializer() + + with session.Session() as sess: + sess.run(init_op) + actual_grad_value = sess.run(grad) + self.assertEqual(expected_grad_value, actual_grad_value) + def testStrippedOpListDef(self): with self.test_session(): # Creates a graph. |