aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/variables.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/variables.py')
-rw-r--r--tensorflow/python/ops/variables.py42
1 files changed, 41 insertions, 1 deletions
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 89a19f377d..32545e4eb3 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -650,6 +650,46 @@ class Variable(object):
"""
return state_ops.count_up_to(self._variable, limit=limit)
+ def load(self, value, session=None):
+ """Load new value into this variable
+
+ Writes new value to variable's memory. Doesn't add ops to the graph.
+
+ This convenience method requires a session where the graph containing this
+ variable has been launched. If no session is passed, the default session is
+ used. See the [Session class](../../api_docs/python/client.md#Session) for
+ more information on launching a graph and on sessions.
+
+ ```python
+ v = tf.Variable([1, 2])
+ init = tf.global_variables_initializer()
+
+ with tf.Session() as sess:
+ sess.run(init)
+ # Usage passing the session explicitly.
+ v.load([2, 3], sess)
+ print(v.eval(sess)) # prints [2 3]
+ # Usage with the default session. The 'with' block
+ # above makes 'sess' the default session.
+ v.load([3, 4], sess)
+ print(v.eval()) # prints [3 4]
+ ```
+
+ Args:
+ value: New variable value
+ session: The session to use to evaluate this variable. If
+ none, the default session is used.
+
+ Raises:
+ ValueError: Session is not passed and no default session
+ """
+ session = session or ops.get_default_session()
+ if session is None:
+ raise ValueError(
+ "Either session argument should be provided or default session "
+ "should be established")
+ session.run(self._initializer_op, {self._initializer_op.inputs[1]: value})
+
# Conversion to tensor.
@staticmethod
def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False): # pylint: disable=invalid-name
@@ -1070,7 +1110,7 @@ def global_variables():
return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
-@deprecated("2016-03-02", "Please use tf.global_variables instead.")
+@deprecated("2017-03-02", "Please use tf.global_variables instead.")
def all_variables():
"""See `tf.global_variables`."""
return global_variables()