From e91ea774cfce323ef5812274f92fa2f13d2f3637 Mon Sep 17 00:00:00 2001 From: Sherry Moore Date: Tue, 13 Sep 2016 09:08:10 -0800 Subject: Added to_proto and from_proto for CondContext and WhileContext. Change: 133013495 --- tensorflow/core/protobuf/control_flow.proto | 66 ++++++++ tensorflow/python/framework/ops.py | 4 + tensorflow/python/ops/control_flow_ops.py | 205 +++++++++++++++++++++++-- tensorflow/python/ops/control_flow_ops_test.py | 31 ++++ tensorflow/python/training/saver.py | 12 +- tensorflow/python/training/saver_test.py | 55 ++++++- 6 files changed, 350 insertions(+), 23 deletions(-) create mode 100644 tensorflow/core/protobuf/control_flow.proto diff --git a/tensorflow/core/protobuf/control_flow.proto b/tensorflow/core/protobuf/control_flow.proto new file mode 100644 index 0000000000..24f42322c0 --- /dev/null +++ b/tensorflow/core/protobuf/control_flow.proto @@ -0,0 +1,66 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ControlFlowProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Control flow context related protocol buffers. + +// Protocol buffer representing the values in ControlFlowContext. +message ValuesDef { + // Value names that have been seen in this context. + repeated string values = 1; + + // Value names referenced by but external to this context. + map<string, string> external_values = 2; +} + +// Protocol buffer representing a CondContext object. +message CondContextDef { + // Name of the context. + string context_name = 1; + + // Name of the pred tensor. + string pred_name = 2; + + // Name of the pivot tensor. + string pivot_name = 3; + + // Branch prediction. 0 or 1. + int32 branch = 4; + + // Values and external values in control flow context. + ValuesDef values_def = 5; +} + +// Protocol buffer representing a WhileContext object. +message WhileContextDef { + // Name of the context. 
+ string context_name = 1; + + // The number of iterations allowed to run in parallel. + int32 parallel_iterations = 2; + + // Whether backprop is enabled for this while loop. + bool back_prop = 3; + + // Whether GPU-CPU memory swap is enabled for this loop. + bool swap_memory = 4; + + // Name of the pivot tensor. + string pivot_name = 5; + + // Name of the pivot_for_pred tensor. + string pivot_for_pred_name = 6; + + // Name of the pivot_for_body tensor. + string pivot_for_body_name = 7; + + // List of names for exit tensors. + repeated string loop_exit_names = 8; + + // Values and external values in control flow context. + ValuesDef values_def = 9; +} diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 8cd5b0fef1..caa9695c8a 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -3948,6 +3948,10 @@ class GraphKeys(object): SUMMARY_OP = "summary_op" GLOBAL_STEP = "global_step" + # Key for control flow context. + COND_CONTEXT = "cond_context" + WHILE_CONTEXT = "while_context" + def add_to_collection(name, value): """Wrapper for `Graph.add_to_collection()` using the default graph. 
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index f1dc8b3515..6b29f18398 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -75,6 +75,7 @@ import collections import six from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.core.protobuf import control_flow_pb2 from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -1333,13 +1334,35 @@ class ControlFlowContext(object): Pushed and popped by ctxt.Enter() and ctxt.Exit() """ - def __init__(self): + def __init__(self, values_def=None): self._outer_context = ops.get_default_graph()._get_control_flow_context() self._context_stack = [] - # Values that have been already seen in this context. - self._values = set() - # Values referenced by but external to this context. + if values_def: + self._init_values_from_proto(values_def) + else: + # Values that have been already seen in this context. + self._values = set() + # Values referenced by but external to this context. + self._external_values = {} + + def _init_values_from_proto(self, values_def): + """Initializes values and external_values from `ValuesDef` protocol buffer. + + Args: + values_def: `ValuesDef` protocol buffer. 
+ """ + assert isinstance(values_def, control_flow_pb2.ValuesDef) + self._values = set(values_def.values) + g = ops.get_default_graph() self._external_values = {} + for k, v in values_def.external_values.items(): + self._external_values[k] = g.as_graph_element(v) + op_names = set([op.split(":")[0] + for op in self._values - set(self._external_values)]) + for op in op_names: + # pylint: disable=protected-access + g.as_graph_element(op)._set_control_flow_context(self) + # pylint: enable=protected-access @property def outer_context(self): @@ -1354,6 +1377,23 @@ class ControlFlowContext(object): def back_prop(self): raise NotImplementedError("Abstract method") + def _to_proto(self): + """Converts the values to a `ValuesDef` protocol buffer. + + Returns: + A `ValuesDef` protocol buffer. + """ + values_def = control_flow_pb2.ValuesDef() + values_def.values.extend([v for v in sorted(self._values)]) + for k, v in self._external_values.items(): + values_def.external_values[k] = v.name + return values_def + + @staticmethod + def _from_proto(values_def): + """Returns a `ControlFlowContext` created from `values_def`.""" + return ControlFlowContext(values_def=values_def) + def AddName(self, name): self._values.add(name) @@ -1428,15 +1468,41 @@ class ControlFlowContext(object): class CondContext(ControlFlowContext): """The context for the conditional construct.""" - def __init__(self, pred, pivot, branch): - ControlFlowContext.__init__(self) - self._pred = pred # The boolean tensor for the cond predicate - self._pivot = pivot # The predicate tensor in this branch - self._branch = branch # 0 or 1 representing this branch + def __init__(self, pred=None, pivot=None, branch=None, + name="cond_text", context_def=None): + self._name = ops.get_default_graph().unique_name(name) - # Values considered to have been already seen in this context. 
- self._values.add(pred.name) - self._values.add(pivot.name) + if context_def: + self._init_from_proto(context_def) + else: + # Initializes the default fields. + ControlFlowContext.__init__(self) + self._pred = pred # The boolean tensor for the cond predicate + self._pivot = pivot # The predicate tensor in this branch + self._branch = branch # 0 or 1 representing this branch + + # Values considered to have been already seen in this context. + self._values.add(pred.name) + self._values.add(pivot.name) + + def _init_from_proto(self, context_def): + """Creates a new `CondContext` from protocol buffer. + + Args: + context_def: `CondContextDef` protocol buffer. + """ + assert isinstance(context_def, control_flow_pb2.CondContextDef) + # Create from context_def. + g = ops.get_default_graph() + self._name = context_def.context_name + self._pred = g.as_graph_element(context_def.pred_name) + self._pivot = g.as_graph_element(context_def.pivot_name) + self._branch = context_def.branch + super(CondContext, self).__init__(values_def=context_def.values_def) + + @property + def name(self): + return self._name @property def pred(self): @@ -1465,6 +1531,26 @@ class CondContext(ControlFlowContext): def GetControlPivot(self): return self._pivot + def to_proto(self): + """Converts a `CondContext` to a `CondContextDef` protocol buffer. + + Returns: + A `CondContextDef` protocol buffer. 
+ """ + context_def = control_flow_pb2.CondContextDef() + context_def.context_name = self.name + context_def.pred_name = self._pred.name + context_def.pivot_name = self._pivot.name + context_def.branch = self._branch + context_def.values_def.MergeFrom(super(CondContext, self)._to_proto()) + + return context_def + + @staticmethod + def from_proto(context_def): + """Returns a `CondContext` object created from `context_def`.""" + return CondContext(context_def=context_def) + def AddValue(self, val): """Add `val` to the current context and its outer context recursively.""" if val.name in self._values: @@ -1650,6 +1736,11 @@ def cond(pred, fn1, fn2, name=None): "%s, %s" % (val_x.dtype.name, val_y.dtype.name)) merges = [merge([x[0], x[1]])[0] for x in zip(res_f, res_t)] merges = _convert_flows_to_tensorarrays(orig_res, merges) + + # Add to collections + ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t) + ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f) + return merges[0] if len(merges) == 1 else merges @@ -1659,9 +1750,30 @@ def cond(pred, fn1, fn2, name=None): class WhileContext(ControlFlowContext): """The context for the loop construct.""" - def __init__(self, parallel_iterations, back_prop, swap_memory, name, - grad_state=None): - ControlFlowContext.__init__(self) + def __init__(self, parallel_iterations=10, back_prop=True, swap_memory=False, + name="while_context", grad_state=None, context_def=None): + if context_def: + self._init_from_proto(context_def) + else: + ControlFlowContext.__init__(self) + self._init_from_args(parallel_iterations, back_prop, swap_memory, + name) + # The gradient loop state. + self._grad_state = grad_state + + def _init_from_args(self, parallel_iterations, back_prop, swap_memory, + name): + """Creates a new `WhileContext` from arguments. + + Args: + parallel_iterations: The number of iterations allowed to run in parallel. + back_prop: Whether backprop is enabled for this while loop. 
+ swap_memory: Whether GPU-CPU memory swap is enabled for this loop. + name: Optional name prefix for the returned tensors. + + Raises: + ValueError: If `parallel_iterations` has invalid value. + """ if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0): raise ValueError("`parallel_iterations` must be a positive integer: " "%s" % parallel_iterations) @@ -1678,8 +1790,30 @@ class WhileContext(ControlFlowContext): self._pivot = None # The list of exit tensors for loop variables. self._loop_exits = None - # The gradient loop state. - self._grad_state = grad_state + + def _init_from_proto(self, context_def): + """Creates a new `WhileContext` from protocol buffer. + + Args: + context_def: `WhileContextDef` protocol buffer. + """ + assert isinstance(context_def, control_flow_pb2.WhileContextDef) + # Create from context_def. + g = ops.get_default_graph() + self._name = context_def.context_name + self._parallel_iterations = context_def.parallel_iterations + self._back_prop = context_def.back_prop + self._swap_memory = context_def.swap_memory + self._pivot_for_pred = g.as_graph_element(context_def.pivot_for_pred_name) + # We use this node to control constants created by the body lambda. + self._pivot_for_body = g.as_graph_element(context_def.pivot_for_body_name) + # The boolean tensor for loop termination condition. Used in code + # generation for gradient computation. + self._pivot = g.as_graph_element(context_def.pivot_name) + # The list of exit tensors for loop variables. + self._loop_exits = [g.as_graph_element(exit_name) + for exit_name in context_def.loop_exit_names] + super(WhileContext, self).__init__(values_def=context_def.values_def) @property def name(self): @@ -1715,6 +1849,31 @@ class WhileContext(ControlFlowContext): """The gradient loop state.""" return self._grad_state + def to_proto(self): + """Converts a `WhileContext` to a `WhileContextDef` protocol buffer. + + Returns: + A `WhileContextDef` protocol buffer. 
+ """ + context_def = control_flow_pb2.WhileContextDef() + context_def.context_name = self.name + context_def.parallel_iterations = self._parallel_iterations + context_def.back_prop = self._back_prop + context_def.swap_memory = self._swap_memory + context_def.pivot_for_pred_name = self._pivot_for_pred.name + context_def.pivot_for_body_name = self._pivot_for_body.name + context_def.pivot_name = self._pivot.name + if self._loop_exits: + context_def.loop_exit_names.extend([l.name for l in self._loop_exits]) + context_def.values_def.MergeFrom(super(WhileContext, self)._to_proto()) + + return context_def + + @staticmethod + def from_proto(context_def): + """Returns a `WhileContext` object created from `context_def`.""" + return WhileContext(context_def=context_def) + def GetWhileContext(self): return self @@ -2357,6 +2516,7 @@ def while_loop(cond, body, loop_vars, shape_invariants=None, nest.assert_same_structure(loop_vars, shape_invariants) context = WhileContext(parallel_iterations, back_prop, swap_memory, name) + ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, context) result = context.BuildLoop(cond, body, loop_vars, shape_invariants) return result @@ -2792,3 +2952,14 @@ ops.RegisterShape("RefMerge")(_MergeShape) ops.RegisterShape("RefSelect")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("RefSwitch")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("Switch")(common_shapes.call_cpp_shape_fn) + + +ops.register_proto_function(ops.GraphKeys.COND_CONTEXT, + proto_type=control_flow_pb2.CondContextDef, + to_proto=CondContext.to_proto, + from_proto=CondContext.from_proto) + +ops.register_proto_function(ops.GraphKeys.WHILE_CONTEXT, + proto_type=control_flow_pb2.WhileContextDef, + to_proto=WhileContext.to_proto, + from_proto=WhileContext.from_proto) diff --git a/tensorflow/python/ops/control_flow_ops_test.py b/tensorflow/python/ops/control_flow_ops_test.py index f4352fcdb5..e66f28062a 100644 --- a/tensorflow/python/ops/control_flow_ops_test.py +++ 
b/tensorflow/python/ops/control_flow_ops_test.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import standard_ops as tf from tensorflow.python.platform import googletest from tensorflow.python.training import momentum +from tensorflow.python.util.protobuf import compare class GroupTestCase(TensorFlowTestCase): @@ -226,5 +227,35 @@ class SwitchTestCase(TensorFlowTestCase): self.assertAllEqual(grad, [1] * 3) +class ContextTest(TensorFlowTestCase): + + def testCondContext(self): + with self.test_session() as sess: + x = tf.constant(2) + y = tf.constant(5) + control_flow_ops.cond(tf.less(x, y), + lambda: tf.mul(x, 17), + lambda: tf.add(y, 23)) + for op in sess.graph.get_operations(): + c = op._get_control_flow_context() + if c: + compare.ProtoEq( + c.to_proto(), + control_flow_ops.CondContext.from_proto(c.to_proto()).to_proto()) + + def testWhileContext(self): + with self.test_session() as sess: + i = tf.constant(0) + c = lambda i: tf.less(i, 10) + b = lambda i: tf.add(i, 1) + tf.while_loop(c, b, [i]) + for op in sess.graph.get_operations(): + c = op._get_control_flow_context() + if c: + compare.ProtoEq( + c.to_proto(), + control_flow_ops.WhileContext.from_proto(c.to_proto()).to_proto()) + + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 2afd879b0c..d5d8daf3b8 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -619,7 +619,17 @@ class BaseSaverBuilder(object): restore_op = self._AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape) - assert restore_op.name.endswith("restore_all"), restore_op.name + # In the following use case, it's possible to have restore_ops be called + # something else: + # - Build inference graph and export a meta_graph. + # - Import the inference meta_graph + # - Extend the inference graph to a train graph. + # - Export a new meta_graph. 
+ # Now the second restore_op will be called "restore_all_1". + # As such, comment out the assert for now until we know whether supporting + # such usage model makes sense. + # + # assert restore_op.name.endswith("restore_all"), restore_op.name return saver_pb2.SaverDef( filename_tensor_name=filename_tensor.name, diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 27147096fb..e3bbf88bfc 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -22,6 +22,7 @@ import math import os.path import time import contextlib +import random import shutil import tempfile @@ -37,6 +38,7 @@ from tensorflow.core.protobuf import queue_runner_pb2 from tensorflow.python.framework import errors from tensorflow.python.framework import function from tensorflow.python.ops import gen_data_flow_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import gfile from tensorflow.python.training import saver as saver_module from tensorflow.python.util import compat @@ -1211,7 +1213,13 @@ class MetaGraphTest(tf.test.TestCase): filename = os.path.join(test_dir, "metafile") with self.test_session(): # Creates a graph. 
- v0 = tf.Variable(10.0, name="v0") + v0 = tf.Variable(1.0, name="v0") + control_flow_ops.cond(tf.less(v0, 10), + lambda: tf.add(v0, 1), + lambda: tf.sub(v0, 1)) + control_flow_ops.while_loop(lambda i: tf.less(i, 10), + lambda i: tf.add(i, 1), + [v0]) var = tf.Variable(tf.constant(0, dtype=tf.int64)) count_up_to = var.count_up_to(3) input_queue = tf.FIFOQueue(30, tf.float32, shared_name="collection_queue") @@ -1240,7 +1248,7 @@ class MetaGraphTest(tf.test.TestCase): self.assertTrue(meta_graph_def.HasField("saver_def")) self.assertTrue(meta_graph_def.HasField("graph_def")) collection_def = meta_graph_def.collection_def - self.assertEqual(len(collection_def), 10) + self.assertEqual(len(collection_def), 12) with tf.Graph().as_default(): # Restores from MetaGraphDef. @@ -1418,7 +1426,13 @@ class MetaGraphTest(tf.test.TestCase): tf.truncated_normal([28, 128], stddev=1.0 / math.sqrt(float(28))), name="weights") - biases = tf.Variable(tf.zeros([128]), + # The use of control_flow_ops.cond here is purely for adding test coverage + # of the save and restore of control flow context (which doesn't make any + # sense here from a machine learning perspective). Typically, biases is + # a simple Variable without the conditions. + biases = tf.Variable(control_flow_ops.cond(tf.less(random.random(), 0.5), + lambda: tf.ones([128]), + lambda: tf.zeros([128])), name="biases") hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases) # Hidden 2 @@ -1427,8 +1441,19 @@ class MetaGraphTest(tf.test.TestCase): tf.truncated_normal([128, 32], stddev=1.0 / math.sqrt(float(128))), name="weights") - biases = tf.Variable(tf.zeros([32]), - name="biases") + + # The use of control_flow_ops.while_loop here is purely for adding test + # coverage of the save and restore of control flow context (which doesn't + # make any sense here from a machine learning perspective). Typically, + # biases is a simple Variable without the conditions. 
+ def loop_cond(it, _): + return it < 2 + def loop_body(it, biases): + biases += tf.constant(0.1, shape=[32]) + return it + 1, biases + _, biases = control_flow_ops.while_loop( + loop_cond, loop_body, + [tf.constant(0), tf.Variable(tf.zeros([32]))]) hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases) # Linear with tf.name_scope("softmax_linear"): @@ -1456,6 +1481,7 @@ class MetaGraphTest(tf.test.TestCase): def _testGraphExtensionRestore(self): test_dir = os.path.join(self.get_temp_dir(), "graph_extension") filename = os.path.join(test_dir, "metafile") + train_filename = os.path.join(test_dir, "train_metafile") saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") with self.test_session(graph=tf.Graph()) as sess: # Restores from MetaGraphDef. @@ -1484,11 +1510,30 @@ class MetaGraphTest(tf.test.TestCase): # Runs train_op. train_op = optimizer.minimize(loss) + tf.add_to_collection("train_op", train_op) + + # Runs train_op. + sess.run(train_op) + + # Generates MetaGraphDef. + tf.train.export_meta_graph(train_filename) + + def _testRestoreFromTrainGraphWithControlContext(self): + test_dir = os.path.join(self.get_temp_dir(), "graph_extension") + train_filename = os.path.join(test_dir, "train_metafile") + saver0_ckpt = os.path.join(test_dir, "saver0.ckpt") + with self.test_session(graph=tf.Graph()) as sess: + # Restores from MetaGraphDef. + new_saver = tf.train.import_meta_graph(train_filename) + # Restores from checkpoint. + new_saver.restore(sess, saver0_ckpt) + train_op = tf.get_collection("train_op")[0] sess.run(train_op) def testGraphExtension(self): self._testGraphExtensionSave() self._testGraphExtensionRestore() + self._testRestoreFromTrainGraphWithControlContext() def testStrippedOpListDef(self): with self.test_session(): -- cgit v1.2.3