diff options
Diffstat (limited to 'tensorflow/python/ops/variables.py')
-rw-r--r-- | tensorflow/python/ops/variables.py | 42 |
1 files changed, 41 insertions, 1 deletions
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 89a19f377d..32545e4eb3 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -650,6 +650,46 @@ class Variable(object): """ return state_ops.count_up_to(self._variable, limit=limit) + def load(self, value, session=None): + """Load new value into this variable + + Writes new value to variable's memory. Doesn't add ops to the graph. + + This convenience method requires a session where the graph containing this + variable has been launched. If no session is passed, the default session is + used. See the [Session class](../../api_docs/python/client.md#Session) for + more information on launching a graph and on sessions. + + ```python + v = tf.Variable([1, 2]) + init = tf.global_variables_initializer() + + with tf.Session() as sess: + sess.run(init) + # Usage passing the session explicitly. + v.load([2, 3], sess) + print(v.eval(sess)) # prints [2 3] + # Usage with the default session. The 'with' block + # above makes 'sess' the default session. + v.load([3, 4], sess) + print(v.eval()) # prints [3 4] + ``` + + Args: + value: New variable value + session: The session to use to evaluate this variable. If + none, the default session is used. + + Raises: + ValueError: Session is not passed and no default session + """ + session = session or ops.get_default_session() + if session is None: + raise ValueError( + "Either session argument should be provided or default session " + "should be established") + session.run(self._initializer_op, {self._initializer_op.inputs[1]: value}) + # Conversion to tensor. @staticmethod def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): # pylint: disable=invalid-name @@ -1070,7 +1110,7 @@ def global_variables(): return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) -@deprecated("2016-03-02", "Please use tf.global_variables instead.") +@deprecated("2017-03-02", "Please use tf.global_variables instead.") def all_variables(): """See `tf.global_variables`.""" return global_variables() |