diff options
author | 2016-02-24 14:50:49 -0800 | |
---|---|---|
committer | 2016-02-24 15:36:11 -0800 | |
commit | 2861cc1d239932a5dfd745c86f30acc8bb947b6d (patch) | |
tree | 4305aeb3a6a76e525d96ea47c4459be438e2f8c0 | |
parent | 497606904be87f7a4078ed7ee0784afaa094b258 (diff) |
Surface control_flow_ops.case to public. Update docs. Add unit tests.
Change: 115496194
-rw-r--r-- | tensorflow/g3doc/api_docs/python/array_ops.md | 114 | ||||
-rw-r--r-- | tensorflow/g3doc/api_docs/python/contrib.layers.md | 31 | ||||
-rw-r--r-- | tensorflow/g3doc/api_docs/python/control_flow_ops.md | 78 | ||||
-rw-r--r-- | tensorflow/g3doc/api_docs/python/framework.md | 9 | ||||
-rw-r--r-- | tensorflow/g3doc/api_docs/python/image.md | 33 | ||||
-rw-r--r-- | tensorflow/g3doc/api_docs/python/index.md | 7 | ||||
-rw-r--r-- | tensorflow/g3doc/api_docs/python/math_ops.md | 18 | ||||
-rw-r--r-- | tensorflow/g3doc/api_docs/python/state_ops.md | 50 | ||||
-rw-r--r-- | tensorflow/g3doc/api_docs/python/train.md | 72 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 44 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 7 | ||||
-rw-r--r-- | tensorflow/python/ops/standard_ops.py | 1 |
12 files changed, 421 insertions, 43 deletions
diff --git a/tensorflow/g3doc/api_docs/python/array_ops.md b/tensorflow/g3doc/api_docs/python/array_ops.md index d532731f5c..c634ee799a 100644 --- a/tensorflow/g3doc/api_docs/python/array_ops.md +++ b/tensorflow/g3doc/api_docs/python/array_ops.md @@ -1285,6 +1285,120 @@ boolean_mask(tensor, mask) ==> [[1, 2], [5, 6]] ``` +- - - + +### `tf.one_hot(indices, depth, on_value, off_value, axis=None, name=None)` {#one_hot} + +Returns a one-hot tensor. + +The locations represented by indices in `indices` take value `on_value`, +while all other locations take value `off_value`. + +If the input `indices` is rank `N`, the output will have rank `N+1`, +The new axis is created at dimension `axis` (default: the new axis is +appended at the end). + +If `indices` is a scalar the output shape will be a vector of length `depth`. + +If `indices` is a vector of length `features`, the output shape will be: +``` + features x depth if axis == -1 + depth x features if axis == 0 +``` + +If `indices` is a matrix (batch) with shape `[batch, features]`, +the output shape will be: +``` + batch x features x depth if axis == -1 + batch x depth x features if axis == 1 + depth x batch x features if axis == 0 +``` + + +Examples +========= + +Suppose that + +``` + indices = [0, 2, -1, 1] + depth = 3 + on_value = 5.0 + off_value = 0.0 + axis = -1 +``` + +Then output is `[4 x 3]`: + + ```output = + [5.0 0.0 0.0] // one_hot(0) + [0.0 0.0 5.0] // one_hot(2) + [0.0 0.0 0.0] // one_hot(-1) + [0.0 5.0 0.0] // one_hot(1) + ``` + +Suppose that + +``` + indices = [0, 2, -1, 1] + depth = 3 + on_value = 0.0 + off_value = 3.0 + axis = 0 +``` + +Then output is `[3 x 4]`: + + ```output = + [0.0 3.0 3.0 3.0] + [3.0 3.0 3.0 0.0] + [3.0 3.0 3.0 3.0] + [3.0 0.0 3.0 3.0] + // ^ one_hot(0) + // ^ one_hot(2) + // ^ one_hot(-1) + // ^ one_hot(1) + ``` +Suppose that + +``` + indices = [[0, 2], [1, -1]] + depth = 3 + on_value = 1.0 + off_value = 0.0 + axis = -1 +``` + +Then output is `[2 x 2 x 3]`: + + ```output = + [ + [1.0, 0.0, 0.0] // one_hot(0) + [0.0, 0.0, 1.0] // one_hot(2) + ][ + [0.0, 1.0, 0.0] // one_hot(1) + [0.0, 0.0, 0.0] // one_hot(-1) + ]``` + +##### Args: + + +* <b>`indices`</b>: A `Tensor` of type `int64`. A tensor of indices. +* <b>`depth`</b>: A `Tensor` of type `int32`. + A scalar defining the depth of the one hot dimension. +* <b>`on_value`</b>: A `Tensor`. + A scalar defining the value to fill in output when `indices[j] = i`. +* <b>`off_value`</b>: A `Tensor`. Must have the same type as `on_value`. + A scalar defining the value to fill in output when `indices[j] != i`. +* <b>`axis`</b>: An optional `int`. Defaults to `-1`. + The axis to fill (default: -1, a new inner-most axis). +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + + A `Tensor`. Has the same type as `on_value`. The one-hot tensor. + + ## Other Functions and Classes - - - diff --git a/tensorflow/g3doc/api_docs/python/contrib.layers.md b/tensorflow/g3doc/api_docs/python/contrib.layers.md index d9351e47ed..78bc28d8d5 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.layers.md +++ b/tensorflow/g3doc/api_docs/python/contrib.layers.md @@ -301,7 +301,7 @@ activation. - - - -### `tf.contrib.layers.summarize_tensor(tensor)` {#summarize_tensor} +### `tf.contrib.layers.summarize_tensor(tensor, tag=None)` {#summarize_tensor} Summarize a tensor using a suitable summary type. @@ -313,6 +313,7 @@ other tensors, `histogram_summary` is used. * <b>`tensor`</b>: The tensor to summarize +* <b>`tag`</b>: The tag to use, if None then use tensor's op's name. ##### Returns: @@ -377,3 +378,31 @@ be `dtypes.float32` or `dtypes.float64`. If neither `tensors` nor float. +- - - + +### `tf.contrib.layers.assert_scalar_int(tensor)` {#assert_scalar_int} + +Assert `tensor` is 0-D, of type `tf.int32` or `tf.int64`. + +##### Args: + + +* <b>`tensor`</b>: Tensor to test. + +##### Returns: + + `tensor`, for chaining. + +##### Raises: + + +* <b>`ValueError`</b>: if `tensor` is not 0-D, of type `tf.int32` or `tf.int64`. + + +- - - + +### `tf.contrib.layers.is_numeric_tensor(tensor)` {#is_numeric_tensor} + + + + diff --git a/tensorflow/g3doc/api_docs/python/control_flow_ops.md b/tensorflow/g3doc/api_docs/python/control_flow_ops.md index c4b4d6e4f4..dded55f87c 100644 --- a/tensorflow/g3doc/api_docs/python/control_flow_ops.md +++ b/tensorflow/g3doc/api_docs/python/control_flow_ops.md @@ -182,6 +182,84 @@ the same non-zero number and type of outputs. ``` +- - - + +### `tf.case(pred_fn_pairs, default, exclusive=False, name='case')` {#case} + +Create a case operation. + +The `pred_fn_pairs` parameter is a dict or list of pairs of size N. +Each pair contains a boolean scalar tensor and a python callable that +creates the tensors to be returned if the boolean evaluates to True. `default` +is a callable generating a list of tensors. All the callables in +`pred_fn_pairs` as well as `default` should return the same number and types +of tensors. + +If `exclusive==True`, all predicates are evaluated, and a logging operation +with an error is returned if more than one of the predicates evaluates to +True. If `exclusive==False`, execution stops are the first predicate which +evaluates to True, and the tensors generated by the corresponding function +are returned immediately. If none of the predicates evaluate to True, this +operation returns the tensors generated by `default`. + +Example 1: + Pseudocode: + ``` + if (x < y) return 17; + else return 23; + ``` + + Expressions: + ``` + f1 = lambda: tf.constant(17) + f2 = lambda: tf.constant(23) + r = case([(tf.less(x, y), f1)], default=f2) + ``` + +Example 2: + Pseudocode: + ``` + if (x < y && x > z) raise OpError("Only one predicate may evaluate true"); + if (x < y) return 17; + else if (x > z) return 23; + else return -1; + ``` + + Expressions: + ``` + x = tf.constant(0) + y = tf.constant(1) + z = tf.constant(2) + def f1(): return tf.constant(17) + def f2(): return tf.constant(23) + def f3(): return tf.constant(-1) + r = case({tf.less(x, y): f1, tf.greater(x, z): f2}, + default=f3, exclusive=True) + ``` + +##### Args: + + +* <b>`pred_fn_pairs`</b>: Dict or list of pairs of a boolean scalar tensor and a + callable which returns a list of tensors. +* <b>`default`</b>: A callable that returns a list of tensors. +* <b>`exclusive`</b>: True iff more than one predicate is allowed to evaluate to True. +* <b>`name`</b>: A name for this operation (optional). + +##### Returns: + + The tensors returned by the first pair whose predicate evaluated to True, or + those returned by `default` if none does. + +##### Raises: + + +* <b>`TypeError`</b>: If `pred_fn_pairs` is not a list/dictionary. +* <b>`TypeError`</b>: If `pred_fn_pairs` is a list but does not contain 2-tuples. +* <b>`TypeError`</b>: If `fns[i]` is not callable for any i, or `default` is not + callable. + + ## Logical Operators diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md index 590355bdab..52d2560dd3 100644 --- a/tensorflow/g3doc/api_docs/python/framework.md +++ b/tensorflow/g3doc/api_docs/python/framework.md @@ -523,7 +523,7 @@ This method may be called concurrently from multiple threads. - - - -#### `tf.Graph.unique_name(name)` {#Graph.unique_name} +#### `tf.Graph.unique_name(name, mark_as_used=True)` {#Graph.unique_name} Return a unique operation name for `name`. @@ -537,10 +537,17 @@ Operation names are displayed in error messages reported by the TensorFlow runtime, and in various visualization tools such as TensorBoard. +If `mark_as_used` is set to `True`, which is the default, a new +unique name is created and marked as in use. If it's set to `False`, +the unique name is returned without actually being marked as used. +This is useful when the caller simply wants to know what the name +to be created will be. + ##### Args: * <b>`name`</b>: The name for an operation. +* <b>`mark_as_used`</b>: Whether to mark this name as being used. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/image.md b/tensorflow/g3doc/api_docs/python/image.md index 5998b591a0..622edac7a9 100644 --- a/tensorflow/g3doc/api_docs/python/image.md +++ b/tensorflow/g3doc/api_docs/python/image.md @@ -400,6 +400,39 @@ dimension. - - - +### `tf.image.central_crop(image, central_fraction)` {#central_crop} + +Crop the central region of the image. + +Remove the outer parts of an image but retain the central region of the image +along each dimension. If we specify central_fraction = 0.5, this function +returns the region marked with "X" in the below diagram. + + -------- + | | + | XXXX | + | XXXX | + | | where "X" is the central 50% of the image. + -------- + +##### Args: + + +* <b>`image`</b>: 3-D float Tensor of shape [height, width, depth] +* <b>`central_fraction`</b>: float (0, 1], fraction of size to crop + +##### Raises: + + +* <b>`ValueError`</b>: if central_crop_fraction is not within (0, 1]. + +##### Returns: + + 3-D float Tensor + + +- - - + ### `tf.image.pad_to_bounding_box(image, offset_height, offset_width, target_height, target_width)` {#pad_to_bounding_box} Pad `image` with zeros to the specified `height` and `width`. diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md index 82f7fa02fc..c85e25bbe8 100644 --- a/tensorflow/g3doc/api_docs/python/index.md +++ b/tensorflow/g3doc/api_docs/python/index.md @@ -92,6 +92,7 @@ * [`dynamic_stitch`](../../api_docs/python/array_ops.md#dynamic_stitch) * [`expand_dims`](../../api_docs/python/array_ops.md#expand_dims) * [`gather`](../../api_docs/python/array_ops.md#gather) + * [`one_hot`](../../api_docs/python/array_ops.md#one_hot) * [`pack`](../../api_docs/python/array_ops.md#pack) * [`pad`](../../api_docs/python/array_ops.md#pad) * [`rank`](../../api_docs/python/array_ops.md#rank) @@ -193,6 +194,7 @@ * [`sparse_segment_sum`](../../api_docs/python/math_ops.md#sparse_segment_sum) * [`sqrt`](../../api_docs/python/math_ops.md#sqrt) * [`square`](../../api_docs/python/math_ops.md#square) + * [`squared_difference`](../../api_docs/python/math_ops.md#squared_difference) * [`sub`](../../api_docs/python/math_ops.md#sub) * [`transpose`](../../api_docs/python/math_ops.md#transpose) * [`truediv`](../../api_docs/python/math_ops.md#truediv) @@ -203,6 +205,7 @@ * **[Control Flow](../../api_docs/python/control_flow_ops.md)**: * [`add_check_numerics_ops`](../../api_docs/python/control_flow_ops.md#add_check_numerics_ops) * [`Assert`](../../api_docs/python/control_flow_ops.md#Assert) + * [`case`](../../api_docs/python/control_flow_ops.md#case) * [`check_numerics`](../../api_docs/python/control_flow_ops.md#check_numerics) * [`cond`](../../api_docs/python/control_flow_ops.md#cond) * [`count_up_to`](../../api_docs/python/control_flow_ops.md#count_up_to) @@ -233,6 +236,7 @@ * [`adjust_contrast`](../../api_docs/python/image.md#adjust_contrast) * [`adjust_hue`](../../api_docs/python/image.md#adjust_hue) * [`adjust_saturation`](../../api_docs/python/image.md#adjust_saturation) + * [`central_crop`](../../api_docs/python/image.md#central_crop) * [`convert_image_dtype`](../../api_docs/python/image.md#convert_image_dtype) * [`crop_to_bounding_box`](../../api_docs/python/image.md#crop_to_bounding_box) * [`decode_jpeg`](../../api_docs/python/image.md#decode_jpeg) @@ -316,6 +320,7 @@ * **[Neural Network](../../api_docs/python/nn.md)**: * [`avg_pool`](../../api_docs/python/nn.md#avg_pool) + * [`batch_normalization`](../../api_docs/python/nn.md#batch_normalization) * [`bias_add`](../../api_docs/python/nn.md#bias_add) * [`compute_accidental_hits`](../../api_docs/python/nn.md#compute_accidental_hits) * [`conv2d`](../../api_docs/python/nn.md#conv2d) @@ -422,8 +427,10 @@ * **[Layers (contrib)](../../api_docs/python/contrib.layers.md)**: * [`assert_same_float_dtype`](../../api_docs/python/contrib.layers.md#assert_same_float_dtype) + * [`assert_scalar_int`](../../api_docs/python/contrib.layers.md#assert_scalar_int) * [`convolution2d`](../../api_docs/python/contrib.layers.md#convolution2d) * [`fully_connected`](../../api_docs/python/contrib.layers.md#fully_connected) + * [`is_numeric_tensor`](../../api_docs/python/contrib.layers.md#is_numeric_tensor) * [`l1_regularizer`](../../api_docs/python/contrib.layers.md#l1_regularizer) * [`l2_regularizer`](../../api_docs/python/contrib.layers.md#l2_regularizer) * [`summarize_activation`](../../api_docs/python/contrib.layers.md#summarize_activation) diff --git a/tensorflow/g3doc/api_docs/python/math_ops.md b/tensorflow/g3doc/api_docs/python/math_ops.md index cdbb54aa6a..ca2ccbd232 100644 --- a/tensorflow/g3doc/api_docs/python/math_ops.md +++ b/tensorflow/g3doc/api_docs/python/math_ops.md @@ -628,6 +628,24 @@ Computes complementary error function of `x` element-wise. the return type is `quint8`. +- - - + +### `tf.squared_difference(x, y, name=None)` {#squared_difference} + +Returns (x - y)(x - y) element-wise. + +##### Args: + + +* <b>`x`</b>: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `complex64`, `int64`. +* <b>`y`</b>: A `Tensor`. Must have the same type as `x`. +* <b>`name`</b>: A name for the operation (optional). + +##### Returns: + + A `Tensor`. Has the same type as `x`. + + ## Matrix Math Functions diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md index 01988995bb..860087939c 100644 --- a/tensorflow/g3doc/api_docs/python/state_ops.md +++ b/tensorflow/g3doc/api_docs/python/state_ops.md @@ -104,7 +104,7 @@ Creating a variable. - - - -#### `tf.Variable.__init__(initial_value=None, trainable=True, collections=None, validate_shape=True, name=None, variable_def=None)` {#Variable.__init__} +#### `tf.Variable.__init__(initial_value=None, trainable=True, collections=None, validate_shape=True, caching_device=None, name=None, variable_def=None)` {#Variable.__init__} Creates a new variable with value `initial_value`. @@ -131,6 +131,11 @@ variable to its initial value. * <b>`validate_shape`</b>: If `False`, allows the variable to be initialized with a value of unknown shape. If `True`, the default, the shape of `initial_value` must be known. +* <b>`caching_device`</b>: Optional device string describing where the Variable + should be cached for reading. Defaults to the Variable's device. + If not `None`, caches on another device. Typical use is to cache + on the device where the Ops using the Variable reside, to deduplicate + copying through `Switch` and other conditional statements. * <b>`name`</b>: Optional name for the variable. Defaults to `'Variable'` and gets uniquified automatically. * <b>`variable_def`</b>: `VariableDef` protocol buffer. If not `None`, recreates @@ -389,7 +394,7 @@ The `Operation` of this variable. #### `tf.Variable.from_proto(variable_def)` {#Variable.from_proto} - +Returns a `Variable` object created from `variable_def`. - - - @@ -742,7 +747,7 @@ path can be passed directly to a call to `restore()`. kept in the same directory as the checkpoint files, is automatically managed by the saver to keep track of recent checkpoints. Defaults to 'checkpoint'. -* <b>`meta_graph_suffix`</b>: Suffix for MetaGraphDef file. Defaults to 'meta'. +* <b>`meta_graph_suffix`</b>: Suffix for `MetaGraphDef` file. Defaults to 'meta'. ##### Returns: @@ -848,7 +853,7 @@ Writes `MetaGraphDef` to save_path/filename. #### `tf.train.Saver.from_proto(saver_def)` {#Saver.from_proto} - +Returns a `Saver` object created from `saver_def`. - - - @@ -873,7 +878,11 @@ Sets the list of old checkpoint filenames and timestamps. #### `tf.train.Saver.to_proto()` {#Saver.to_proto} -Returns a `SaverDef` protocol buffer. +Converts this `Saver` to a `SaverDef` protocol buffer. + +##### Returns: + + A `SaverDef` protocol buffer. @@ -1022,17 +1031,26 @@ Attributes: initializer: default initializer passed to get_variable. regularizer: default regularizer passed to get_variable. reuse: Boolean or None, setting the reuse in get_variable. + caching_device: string, callable, or None: the caching device passed to + get_variable. name_scope: The name passed to tf.name_scope. - - - -#### `tf.VariableScope.__init__(reuse, name='', initializer=None, regularizer=None, name_scope='')` {#VariableScope.__init__} +#### `tf.VariableScope.__init__(reuse, name='', initializer=None, regularizer=None, caching_device=None, name_scope='')` {#VariableScope.__init__} Creates a new VariableScope with the given properties. - - - -#### `tf.VariableScope.get_variable(var_store, name, shape=None, dtype=tf.float32, initializer=None, regularizer=None, trainable=True, collections=None)` {#VariableScope.get_variable} +#### `tf.VariableScope.caching_device` {#VariableScope.caching_device} + + + + +- - - + +#### `tf.VariableScope.get_variable(var_store, name, shape=None, dtype=tf.float32, initializer=None, regularizer=None, trainable=True, collections=None, caching_device=None)` {#VariableScope.get_variable} Gets an existing variable with this name or create a new one. @@ -1074,6 +1092,13 @@ Reuse variables in this scope. - - - +#### `tf.VariableScope.set_caching_device(caching_device)` {#VariableScope.set_caching_device} + +Set caching_device for this scope. + + +- - - + #### `tf.VariableScope.set_initializer(initializer)` {#VariableScope.set_initializer} Set initializer for this scope. @@ -1089,7 +1114,7 @@ Set regularizer for this scope. - - - -### `tf.variable_scope(name_or_scope, reuse=None, initializer=None, regularizer=None)` {#variable_scope} +### `tf.variable_scope(name_or_scope, reuse=None, initializer=None, regularizer=None, caching_device=None)` {#variable_scope} Returns a context for variable scope. @@ -1157,6 +1182,7 @@ then all its sub-scopes become reusing as well. well as all sub-scopes; if `None`, we just inherit the parent scope reuse. * <b>`initializer`</b>: default initializer for variables within this scope. * <b>`regularizer`</b>: default regularizer for variables within this scope. +* <b>`caching_device`</b>: default caching device for variables within this scope. ##### Returns: @@ -1172,7 +1198,7 @@ then all its sub-scopes become reusing as well. - - - -### `tf.variable_op_scope(values, name, default_name, initializer=None, regularizer=None)` {#variable_op_scope} +### `tf.variable_op_scope(values, name, default_name, initializer=None, regularizer=None, caching_device=None)` {#variable_op_scope} Returns a context manager for defining an op that creates variables. @@ -1208,8 +1234,10 @@ def my_op_with_vars(a, b, name=None): uniquified in the variable scope. * <b>`default_name`</b>: The default name to use if the `name` argument is `None`, this name will be uniquified. -* <b>`initializer`</b>: A default initializer to pass to variable scope. -* <b>`regularizer`</b>: default regularizer for variables within this scope. +* <b>`initializer`</b>: The default initializer to pass to variable scope. +* <b>`regularizer`</b>: The default regularizer for variables within this scope. +* <b>`caching_device`</b>: The default caching device for variables within this scope. + ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md index 3e78ad3eac..ba5f270eca 100644 --- a/tensorflow/g3doc/api_docs/python/train.md +++ b/tensorflow/g3doc/api_docs/python/train.md @@ -410,36 +410,15 @@ current good choice is 1.0 or 0.1. Optimizer that implements the FTRL algorithm. +See this [paper]( +https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf). + - - - #### `tf.train.FtrlOptimizer.__init__(learning_rate, learning_rate_power=-0.5, initial_accumulator_value=0.1, l1_regularization_strength=0.0, l2_regularization_strength=0.0, use_locking=False, name='Ftrl')` {#FtrlOptimizer.__init__} Construct a new FTRL optimizer. -The Ftrl-proximal algorithm, abbreviated for Follow-the-regularized-leader, -is described in the paper [Ad Click Prediction: a View from the Trenches]( -https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf). - -It can give a good performance vs. sparsity tradeoff. - -Ftrl-proximal uses its own global base learning rate and can behave like -Adagrad with `learning_rate_power=-0.5`, or like gradient descent with -`learning_rate_power=0.0`. - -The effective learning rate is adjusted per parameter, relative to this -base learning rate as: - -``` -effective_learning_rate_i = (learning_rate / - pow(k + summed_squared_gradients_for_i, learning_rate_power)); -``` - -where k is the small constant `initial_accumulator_value`. - -Note that the real regularization coefficient of `|w|^2` for objective -function is `1 / lambda_2` if specifying `l2 = lambda_2` as argument when -using this function. - ##### Args: @@ -1442,7 +1421,7 @@ depending on whether or not a `Coordinator` was passed to #### `tf.train.QueueRunner.from_proto(queue_runner_def)` {#QueueRunner.from_proto} - +Returns a `QueueRunner` object created from `queue_runner_def`. - - - @@ -1936,7 +1915,7 @@ global_step: 10 ##### Args: -* <b>`sess`</b>: A brain `Session` object. +* <b>`sess`</b>: A TensorFlow `Session` object. * <b>`global_step_tensor`</b>: `Tensor` or the `name` of the operation that contains the global step. @@ -2237,9 +2216,12 @@ Generates a checkpoint state proto. Recreates a Graph saved in a `MetaGraphDef` proto. -This function reads from a file containing a `MetaGraphDef` proto, -adds all the nodes from the graph_def proto to the current graph, -recreates all the collections, and returns a saver from saver_def. +This function takes a `MetaGraphDef` protocol buffer as input. If +the argument is a file containing a `MetaGraphDef` protocol buffer , +it constructs a protocol buffer from the file content. The function +then adds all the nodes from the `graph_def` field to the +current graph, recreates all the collections, and returns a saver +constructed from the `saver_def` field. In combination with `export_meta_graph()`, this function can be used to @@ -2250,6 +2232,38 @@ In combination with `export_meta_graph()`, this function can be used to * Run inference from a saved graph and checkpoints. +```Python +... +# Create a saver. +saver = tf.train.Saver(...variables...) +# Remember the training_op we want to run by adding it to a collection. +tf.add_to_collection('train_op', train_op) +sess = tf.Session() +for step in xrange(1000000): + sess.run(train_op) + if step % 1000 == 0: + # Saves checkpoint, which by default also exports a meta_graph + # named 'my-model-global_step.meta'. + saver.save(sess, 'my-model', global_step=step) +``` + +Later we can continue training from this saved `meta_graph` without building +the model from scratch. + +```Python +with tf.Session() as sess: + new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta') + new_saver.restore(sess, 'my-save-dir/my-model-10000') + # tf.get_collection() retrurns a list. In this example we only want the + # first one. + train_op = tf.get_collection('train_op')[0] + for step in xrange(1000000): + sess.run(train_op) +``` + +NOTE: Restarting training from saved `meta_graph` only works if the +device assignments have not changed. + ##### Args: diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 6b84896bfa..c501682cc5 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -1232,6 +1232,50 @@ class ControlFlowTest(tf.test.TestCase): self.assertAllClose(4.0, i.eval(feed_dict={d: 1})) self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2})) + def testCase(self): + with self.test_session(): + x = tf.constant(1) + y = tf.constant(2) + z = tf.constant(3) + f1 = lambda: tf.constant(17) + f2 = lambda: tf.constant(23) + f3 = lambda: tf.constant(-1) + + r1 = tf.case({x < y: f1, x > z: f2}, default=f3, exclusive=True) + self.assertAllEqual(r1.eval(), 17) + + r2 = tf.case([(y > z, f1), (y > x, f2)], default=f3) + self.assertAllEqual(r2.eval(), 23) + + # Duplicate events can happen, first one is selected + r3 = tf.case([(x < y, f1), (x < y, f2)], default=f3) + self.assertAllEqual(r3.eval(), 17) + + # Duplicate events cause an error if exclusive = True + r4 = tf.case([(x < y, f1), (x < y, f2)], default=f3, exclusive=True) + with self.assertRaisesOpError( + "More than one condition evaluated as True but exclusive=True."): + r4.eval() + + # Check that the default is called if none of the others are + r5 = tf.case({x > y: f1}, default=f3) + self.assertAllEqual(r5.eval(), -1) + + ran_once = [False, False, False] + + def break_run_twice(ix): + def _break(): + assert not ran_once[ix] + ran_once[ix] = True + return tf.constant(ix) + return _break + + # Should not fail - each conditional gets called exactly once + r6 = tf.case([(x < y, break_run_twice(0)), (x > y, break_run_twice(1))], + default=break_run_twice(2)) + + self.assertAllEqual(r6.eval(), 0) + def testOneOpCond(self): with self.test_session(): v = tf.Variable(0) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 30918d84c9..112c1d9907 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -24,6 +24,7 @@ the execution of operations and add conditional dependencies to your graph. @@no_op @@count_up_to @@cond +@@case ## Logical Operators @@ -82,6 +83,7 @@ from tensorflow.python.ops import constant_op from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import gen_data_flow_ops +from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops # pylint: disable=wildcard-import,undefined-variable @@ -1974,6 +1976,9 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"): Expressions: ``` + x = tf.constant(0) + y = tf.constant(1) + z = tf.constant(2) def f1(): return tf.constant(17) def f2(): return tf.constant(23) def f3(): return tf.constant(-1) @@ -2050,7 +2055,7 @@ def case(pred_fn_pairs, default, exclusive=False, name="case"): # and prev_case_seq will loop from case_sequence[0] to case_sequence[-1] if exclusive: # TODO(ebrevdo): Add Where() for DT_BOOL, replace with Size(Where(preds)) - preds_c = array_ops.concat(0, preds, name="preds_c") + preds_c = array_ops.pack(preds, name="preds_c") num_true_conditions = math_ops.reduce_sum( math_ops.cast(preds_c, dtypes.int32), name="num_true_conds") at_most_one_true_condition = math_ops.less( diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index 57ff78cdde..b8e6b26468 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -36,6 +36,7 @@ from tensorflow.python.ops.control_flow_ops import group from tensorflow.python.ops.control_flow_ops import no_op from tensorflow.python.ops.control_flow_ops import tuple from tensorflow.python.ops.control_flow_ops import cond +from tensorflow.python.ops.control_flow_ops import case from tensorflow.python.ops.data_flow_ops import * from tensorflow.python.ops.gradients import * from tensorflow.python.ops.init_ops import * |