From f41959ccb2d9d4c722fe8fc3351401d53bcf4900 Mon Sep 17 00:00:00 2001 From: Manjunath Kudlur Date: Fri, 6 Nov 2015 16:27:58 -0800 Subject: TensorFlow: Initial commit of TensorFlow library. TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108 --- tensorflow/python/ops/state_ops.py | 189 +++++++++++++++++++++++++++++++++++++ 1 file changed, 189 insertions(+) create mode 100644 tensorflow/python/ops/state_ops.py (limited to 'tensorflow/python/ops/state_ops.py') diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py new file mode 100644 index 0000000000..1c8f38b94c --- /dev/null +++ b/tensorflow/python/ops/state_ops.py @@ -0,0 +1,189 @@ +"""## Variables + +@@Variable + +## Variable helper functions + +TensorFlow provides a set of functions to help manage the set of variables +collected in the graph. + +@@all_variables +@@trainable_variables + +@@initialize_all_variables +@@initialize_variables +@@assert_variables_initialized + +## Saving and Restoring Variables. + +@@Saver + +@@latest_checkpoint + +@@get_checkpoint_state +@@update_checkpoint_state + +## Sharing Variables + +TensorFlow provides several classes and operations that you can use to +create variables contingent on certain conditions. + +@@get_variable +@@get_variable_scope +@@variable_scope + +@@constant_initializer +@@random_normal_initializer +@@truncated_normal_initializer +@@random_uniform_initializer +@@uniform_unit_scaling_initializer +@@zeros_initializer + +## Sparse Variable Updates + +The sparse update ops modify a subset of the entries in a dense `Variable`, +either overwriting the entries or adding / subtracting a delta. These are +useful for training embedding models and similar lookup-based networks, since +only a small subset of embedding vectors change in any given step. + +Since a sparse update of a large tensor may be generated automatically during +gradient computation (as in the gradient of [`tf.gather`](array_ops.md#gather)), +an [`IndexedSlices`](#IndexedSlices) class is provided that encapsulates a set +of sparse indices and values. `IndexedSlices` objects are detected and handled +automatically by the optimizers in most cases. + +@@scatter_update +@@scatter_add +@@scatter_sub +@@sparse_mask +@@IndexedSlices +""" + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import common_shapes +from tensorflow.python.ops import gen_state_ops +# pylint: disable=wildcard-import,undefined-variable +from tensorflow.python.ops.gen_state_ops import * + + +# pylint: disable=protected-access +def variable_op(shape, dtype, name="Variable", set_shape=True, container="", + shared_name=""): + """Create a variable Operation. + + See also variables.Variable. + + Args: + shape: The shape of the tensor managed by this variable + dtype: The underlying type of the tensor values. + name: optional name to use for the variable op. + set_shape: If True, set the shape property of the returned Tensor to + the shape argument. + container: An optional string. Defaults to "". + If non-empty, this variable is placed in the given container. + Otherwise, a default container is used. + shared_name: An optional string. Defaults to "". + If non-empty, this variable is named in the given bucket + with this shared_name. Otherwise, the node name is used instead. + + Returns: + A variable tensor. + """ + ret = gen_state_ops._variable(shape=shape, dtype=dtype, name=name, + container=container, shared_name=shared_name) + # TODO(mrry): Move this to where it is used, so we can get rid of this op + # wrapper? + if set_shape: + ret.set_shape(shape) + return ret + + +# NOTE(mrry): Shapes are conditionally set in the Python wrapper. +ops.RegisterShape("Variable")(common_shapes.unknown_shape) + + +@ops.RegisterShape("TemporaryVariable") +def _TemporaryVariableShape(op): + """Shape function for the TemporaryVariable op.""" + shape = tensor_util.TensorShapeProtoToList(op.get_attr("shape")) + return [tensor_shape.TensorShape(shape)] + + +@ops.RegisterShape("DestroyTemporaryVariable") +def _DestroyTemporaryVariableShape(op): + """Shape function for the DestroyTemporaryVariable op.""" + return [op.inputs[0].get_shape()] + + +def init_variable(v, init, name="init"): + """Initializes variable with "init". + + This op does the following: + if init is a Tensor, v = init + if callable(init): v = init(VariableShape(v), v.dtype) + + Args: + v: Variable to initialize + init: Tensor to assign to v, + Or an object convertible to Tensor e.g. nparray, + Or an Initializer that generates a tensor given the shape and type of v. + An "Initializer" is a callable that returns a tensor that "v" should be + set to. It will be called as init(shape, dtype). + name: Optional name for the op. + + Returns: + The operation that initializes v. + """ + with ops.op_scope([v, init], None, v.op.name + "/"): + with ops.name_scope(name) as scope: + with ops.device(v.device or ops.get_default_graph().get_default_device()): + if callable(init): + assert v.get_shape().is_fully_defined(), "Variable shape unknown." + # TODO(mrry): Convert to v.shape when the property and + # accessor are reconciled (and all initializers support + # tf.TensorShape objects). + value = init(v.get_shape().as_list(), v.dtype.base_dtype) + value = ops.convert_to_tensor(value, name="value") + return assign(v, value, name=scope) + else: + init = ops.convert_to_tensor(init, name="init") + return assign(v, init, name=scope) + + +@ops.RegisterShape("Assign") +def _AssignShape(op): + """Shape function for the Assign op.""" + if op.get_attr("validate_shape"): + # NOTE(mrry): Return a known shape here. This makes it awkward to + # chain a validated-shape assignment and a reshaping assignment, + # but that is a sufficiently niche case that supporting it does + # not seem worthwhile. + return [op.inputs[0].get_shape().merge_with(op.inputs[1].get_shape())] + return [op.inputs[1].get_shape()] + + +@ops.RegisterShape("AssignAdd") +@ops.RegisterShape("AssignSub") +def _AssignUpdateShape(op): + """Shape function for the AssignAdd and AssignSub dense update ops.""" + return [op.inputs[0].get_shape().merge_with(op.inputs[1].get_shape())] + + +@ops.RegisterShape("CountUpTo") +def _CountUpToShape(op): + """Shape function for the CountUpTo op.""" + return [op.inputs[0].get_shape().merge_with(tensor_shape.scalar())] + + +@ops.RegisterShape("ScatterAdd") +@ops.RegisterShape("ScatterSub") +@ops.RegisterShape("ScatterUpdate") +def _ScatterUpdateShape(op): + """Shape function for the sparse update ops.""" + var_shape = op.inputs[0].get_shape() + indices_shape = op.inputs[1].get_shape() + unused_updates_shape = op.inputs[2].get_shape().merge_with( + indices_shape.concatenate(var_shape[1:])) + return [var_shape] -- cgit v1.2.3