aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-07-18 10:16:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-18 10:22:51 -0700
commit9cc29a75ce8131db67b48e92dac3c16a255b92ed (patch)
tree73bf7a7483d8f7ae3872437609b6943218938ff4
parent491b2d61156333c44e6bf06e2ac0a7ac02c4d310 (diff)
Allows constructing resource variables from tf.Variable.
Also adds arguments to control distributed aggregation to the tf.Variable constructor. Removes tfe.Variable from examples as it's now unnecessary. PiperOrigin-RevId: 205096552
-rw-r--r--tensorflow/contrib/checkpoint/python/containers.py6
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py2
-rw-r--r--tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py5
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb7
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/revnet_test.py2
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py2
-rw-r--r--tensorflow/contrib/eager/python/examples/sagan/sagan.py2
-rw-r--r--tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb2
-rw-r--r--tensorflow/contrib/eager/python/tfe_test.py7
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt2
-rw-r--r--tensorflow/docs_src/guide/eager.md22
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py9
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py13
-rw-r--r--tensorflow/python/ops/variable_scope.py33
-rw-r--r--tensorflow/python/ops/variables.py75
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-variable.pbtxt2
-rw-r--r--third_party/examples/eager/spinn/spinn.py2
17 files changed, 106 insertions, 87 deletions
diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py
index 4d3d531299..242c1e8ba4 100644
--- a/tensorflow/contrib/checkpoint/python/containers.py
+++ b/tensorflow/contrib/checkpoint/python/containers.py
@@ -35,9 +35,9 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure):
self.slotdeps = tf.contrib.checkpoint.UniqueNameTracker()
slotdeps = self.slotdeps
slots = []
- slots.append(slotdeps.track(tfe.Variable(3.), "x")) # Named "x"
- slots.append(slotdeps.track(tfe.Variable(4.), "y"))
- slots.append(slotdeps.track(tfe.Variable(5.), "x")) # Named "x_1"
+ slots.append(slotdeps.track(tf.Variable(3.), "x")) # Named "x"
+ slots.append(slotdeps.track(tf.Variable(4.), "y"))
+ slots.append(slotdeps.track(tf.Variable(5.), "x")) # Named "x_1"
```
"""
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
index 729d8525fa..275aee5130 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
@@ -54,7 +54,7 @@ class Dynamics(tf.keras.Model):
self.position_fn = neural_nets.GenericNet(x_dim, factor=2.)
self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.)
- self.eps = tfe.Variable(
+ self.eps = tf.Variable(
initial_value=eps, name="eps", dtype=tf.float32, trainable=True)
def apply_transition(self, position):
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
index e230ad5e25..68e0bc3123 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
@@ -25,7 +25,6 @@ from __future__ import division
from __future__ import print_function
import tensorflow as tf
-import tensorflow.contrib.eager as tfe
class GenericNet(tf.keras.Model):
@@ -47,13 +46,13 @@ class GenericNet(tf.keras.Model):
# Scale
self.scale_layer = _custom_dense(x_dim, .001)
- self.coeff_scale = tfe.Variable(
+ self.coeff_scale = tf.Variable(
initial_value=tf.zeros([1, x_dim]), name='coeff_scale', trainable=True)
# Translation
self.translation_layer = _custom_dense(x_dim, factor=.001)
# Transformation
self.transformation_layer = _custom_dense(x_dim, .001)
- self.coeff_transformation = tfe.Variable(
+ self.coeff_transformation = tf.Variable(
initial_value=tf.zeros([1, x_dim]),
name='coeff_transformation',
trainable=True)
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
index 591e2d0c85..5f1b48fa0d 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
@@ -118,7 +118,6 @@
"cell_type": "code",
"source": [
"import tensorflow as tf\n",
- "tfe = tf.contrib.eager # Shorthand for some symbols\n",
"\n",
"tf.enable_eager_execution()"
],
@@ -184,7 +183,7 @@
},
"cell_type": "code",
"source": [
- "v = tfe.Variable(1.0)\n",
+ "v = tf.Variable(1.0)\n",
"assert v.numpy() == 1.0\n",
"\n",
"# Re-assign the value\n",
@@ -258,8 +257,8 @@
" def __init__(self):\n",
" # Initialize variable to (5.0, 0.0)\n",
" # In practice, these should be initialized to random values.\n",
- " self.W = tfe.Variable(5.0)\n",
- " self.b = tfe.Variable(0.0)\n",
+ " self.W = tf.Variable(5.0)\n",
+ " self.b = tf.Variable(0.0)\n",
" \n",
" def __call__(self, x):\n",
" return self.W * x + self.b\n",
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
index b2ac4b67c9..b0d0a5486d 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -138,7 +138,7 @@ class RevNetTest(tf.test.TestCase):
minval=0,
maxval=self.config.n_classes,
dtype=tf.int32)
- global_step = tfe.Variable(0., trainable=False)
+ global_step = tf.Variable(0., trainable=False)
model = revnet.RevNet(config=config)
model(x)
updates = model.get_updates_for(x)
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
index c2340a293a..d64bf5354e 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
@@ -310,7 +310,7 @@ def main(_):
with tf.device("/device:GPU:0" if have_gpu else None):
# Make learning_rate a Variable so it can be included in the checkpoint
# and we can resume training with the last saved learning_rate.
- learning_rate = tfe.Variable(20.0, name="learning_rate")
+ learning_rate = tf.Variable(20.0, name="learning_rate")
model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim,
FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout,
use_cudnn_rnn)
diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan.py b/tensorflow/contrib/eager/python/examples/sagan/sagan.py
index 561be36c91..8130414985 100644
--- a/tensorflow/contrib/eager/python/examples/sagan/sagan.py
+++ b/tensorflow/contrib/eager/python/examples/sagan/sagan.py
@@ -62,7 +62,7 @@ class SelfAttentionModule(tf.keras.Model):
kernel_size=1,
strides=(1, 1),
data_format=data_format)
- self.scale = tfe.Variable(0., trainable=True)
+ self.scale = tf.Variable(0., trainable=True)
def call(self, x):
f = self.f(x)
diff --git a/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb
index 4f1410e00b..f3a65f5aab 100644
--- a/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb
+++ b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb
@@ -69,7 +69,7 @@
"cell_type": "code",
"source": [
"# Creating variables\n",
- "v = tfe.Variable(1.0)\n",
+ "v = tf.Variable(1.0)\n",
"v"
],
"execution_count": 2,
diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py
index db50b33af2..4454abfb96 100644
--- a/tensorflow/contrib/eager/python/tfe_test.py
+++ b/tensorflow/contrib/eager/python/tfe_test.py
@@ -27,7 +27,6 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numerics
-from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.summary.writer import writer
@@ -45,12 +44,6 @@ class TFETest(test_util.TensorFlowTestCase):
r'indices = 7 is not in \[0, 3\)'):
array_ops.gather([0, 1, 2], 7)
- def testVariableError(self):
- with self.assertRaisesRegexp(
- RuntimeError,
- r'Variable not supported when eager execution is enabled'):
- variables.Variable(initial_value=1.0)
-
def testGradients(self):
def square(x):
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt
index b07ee9fda9..17b79ee30c 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt
@@ -51,7 +51,7 @@ For example, say we want to update 4 scattered elements to a rank-1 tensor to
8 elements. In Python, that update would look like this:
```python
- ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1] ,[7]])
updates = tf.constant([9, 10, 11, 12])
update = tf.scatter_nd_update(ref, indices, updates)
diff --git a/tensorflow/docs_src/guide/eager.md b/tensorflow/docs_src/guide/eager.md
index e98206eef9..42ad9652f8 100644
--- a/tensorflow/docs_src/guide/eager.md
+++ b/tensorflow/docs_src/guide/eager.md
@@ -225,7 +225,7 @@ the tape backwards and then discard. A particular `tf.GradientTape` can only
compute one gradient; subsequent calls throw a runtime error.
```py
-w = tfe.Variable([[1.0]])
+w = tf.Variable([[1.0]])
with tf.GradientTape() as tape:
loss = w * w
@@ -260,8 +260,8 @@ def grad(weights, biases):
train_steps = 200
learning_rate = 0.01
# Start with arbitrary values for W and B on the same batch of data
-W = tfe.Variable(5.)
-B = tfe.Variable(10.)
+W = tf.Variable(5.)
+B = tf.Variable(10.)
print("Initial loss: {:.3f}".format(loss(W, B)))
@@ -407,11 +407,11 @@ with tf.device("/gpu:0"):
### Variables and optimizers
-`tfe.Variable` objects store mutable `tf.Tensor` values accessed during
+`tf.Variable` objects store mutable `tf.Tensor` values accessed during
training to make automatic differentiation easier. The parameters of a model can
be encapsulated in classes as variables.
-Better encapsulate model parameters by using `tfe.Variable` with
+Better encapsulate model parameters by using `tf.Variable` with
`tf.GradientTape`. For example, the automatic differentiation example above
can be rewritten:
@@ -419,8 +419,8 @@ can be rewritten:
class Model(tf.keras.Model):
def __init__(self):
super(Model, self).__init__()
- self.W = tfe.Variable(5., name='weight')
- self.B = tfe.Variable(10., name='bias')
+ self.W = tf.Variable(5., name='weight')
+ self.B = tf.Variable(10., name='bias')
def call(self, inputs):
return inputs * self.W + self.B
@@ -498,17 +498,17 @@ is removed, and is then deleted.
```py
with tf.device("gpu:0"):
- v = tfe.Variable(tf.random_normal([1000, 1000]))
+ v = tf.Variable(tf.random_normal([1000, 1000]))
v = None # v no longer takes up GPU memory
```
### Object-based saving
-`tfe.Checkpoint` can save and restore `tfe.Variable`s to and from
+`tfe.Checkpoint` can save and restore `tf.Variable`s to and from
checkpoints:
```py
-x = tfe.Variable(10.)
+x = tf.Variable(10.)
checkpoint = tfe.Checkpoint(x=x) # save as "x"
@@ -612,7 +612,7 @@ def line_search_step(fn, init_x, rate=1.0):
`tf.GradientTape` is a powerful interface for computing gradients, but there
is another [Autograd](https://github.com/HIPS/autograd)-style API available for
automatic differentiation. These functions are useful if writing math code with
-only tensors and gradient functions, and without `tfe.Variables`:
+only tensors and gradient functions, and without `tf.Variables`:
* `tfe.gradients_function` —Returns a function that computes the derivatives
of its input function parameter with respect to its arguments. The input
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index e358293a90..c739cd2c0d 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -246,6 +246,15 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[2]])
+ def testUseResource(self):
+ v = variables.Variable(1.0, use_resource=True)
+ self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable))
+
+ def testEagerNoUseResource(self):
+ with context.eager_mode():
+ v = variables.Variable(1.0)
+ self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable))
+
@test_util.run_in_graph_and_eager_modes
def testScatterMin(self):
with ops.device("cpu:0"):
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 1f56ad25bf..5979b76ff2 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -1294,3 +1294,16 @@ def is_resource_variable(var):
""""Returns True if `var` is to be considered a ResourceVariable."""
return isinstance(var, ResourceVariable) or hasattr(
var, "_should_act_as_resource_variable")
+
+
+_DEFAULT_USE_RESOURCE = False
+
+
+def _default_variable_creator(_, *args, **kwds):
+ use_resource = kwds.pop("use_resource", _DEFAULT_USE_RESOURCE)
+ use_resource = use_resource or context.executing_eagerly()
+ if use_resource:
+ return ResourceVariable(*args, **kwds)
+ return variables.RefVariable(*args, **kwds)
+
+variables.default_variable_creator = _default_variable_creator
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 77f67c18ee..0f37dcc027 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -191,36 +191,9 @@ class _ReuseMode(enum.Enum):
# REUSE_TRUE = 3
-@tf_export("VariableSynchronization")
-class VariableSynchronization(enum.Enum):
- """Indicates when a distributed variable will be synced."""
-
- # Indicates that the synchronization will be determined by the current
- # `DistributionStrategy` (eg. With `MirroredStrategy` this would be
- # `ON_WRITE`).
- AUTO = 0
-
- # Indicates that there will only be one copy of the variable, so there is no
- # need to sync.
- NONE = 1
-
- # Indicates that the variable will be aggregated across devices
- # every time it is updated.
- ON_WRITE = 2
-
- # Indicates that the variable will be aggregated across devices
- # when it is read (eg. when checkpointing or when evaluating an op that uses
- # the variable).
- ON_READ = 3
-
-
-@tf_export("VariableAggregation")
-class VariableAggregation(enum.Enum):
- """Indicates how a distributed variable will be aggregated."""
- NONE = 0
- SUM = 1
- MEAN = 2
-
+# TODO(apassos) remove these forwarding symbols.
+VariableSynchronization = variables.VariableSynchronization # pylint: disable=invalid-name
+VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name
AUTO_REUSE = _ReuseMode.AUTO_REUSE
tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE")
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 87e0de197c..6bb2d6f669 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import enum # pylint: disable=g-bad-import-order
+
import six
from tensorflow.core.framework import attr_value_pb2
@@ -38,8 +40,9 @@ from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
-def _default_variable_creator(_, *args, **kwds):
- return RefVariable(*args, **kwds)
+def default_variable_creator(_, *args, **kwds):
+ del args, kwds
+ raise NotImplementedError("resource_variable_ops needs to be imported")
def _make_getter(captured_getter, captured_previous):
@@ -49,12 +52,43 @@ def _make_getter(captured_getter, captured_previous):
return getter
+@tf_export("VariableSynchronization")
+class VariableSynchronization(enum.Enum):
+ """Indicates when a distributed variable will be synced."""
+
+ # Indicates that the synchronization will be determined by the current
+ # `DistributionStrategy` (eg. With `MirroredStrategy` this would be
+ # `ON_WRITE`).
+ AUTO = 0
+
+ # Indicates that there will only be one copy of the variable, so there is no
+ # need to sync.
+ NONE = 1
+
+ # Indicates that the variable will be aggregated across devices
+ # every time it is updated.
+ ON_WRITE = 2
+
+ # Indicates that the variable will be aggregated across devices
+ # when it is read (eg. when checkpointing or when evaluating an op that uses
+ # the variable).
+ ON_READ = 3
+
+
+@tf_export("VariableAggregation")
+class VariableAggregation(enum.Enum):
+ """Indicates how a distributed variable will be aggregated."""
+ NONE = 0
+ SUM = 1
+ MEAN = 2
+
+
class VariableMetaclass(type):
"""Metaclass to allow construction of tf.Variable to be overridden."""
def __call__(cls, *args, **kwargs):
if cls is Variable:
- previous_getter = lambda *a, **k: _default_variable_creator(None, *a, **k)
+ previous_getter = lambda *a, **k: default_variable_creator(None, *a, **k)
# TODO(apassos) use a stack of getters here
return previous_getter(*args, **kwargs)
else:
@@ -172,14 +206,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
* Replace `tf.Variable` with `tf.contrib.eager.Variable`;
* Call `tf.get_variable_scope().set_use_resource(True)` inside a
`tf.variable_scope` before the `tf.get_variable()` call.
-
- @compatibility(eager)
- `tf.Variable` is not compatible with eager execution. Use
- `tf.contrib.eager.Variable` instead which is compatible with both eager
- execution and graph construction. See [the TensorFlow Eager Execution
- guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers)
- for details on how variables work in eager execution.
- @end_compatibility
"""
def __init__(self,
@@ -193,7 +219,10 @@ class Variable(six.with_metaclass(VariableMetaclass,
dtype=None,
expected_shape=None,
import_scope=None,
- constraint=None):
+ constraint=None,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Creates a new variable with value `initial_value`.
The new variable is added to the graph collections listed in `collections`,
@@ -245,20 +274,24 @@ class Variable(six.with_metaclass(VariableMetaclass,
variable and return the Tensor for the projected value
(which must have the same shape). Constraints are not safe to
use when doing asynchronous distributed training.
+ use_resource: if True, a ResourceVariable is created; otherwise an
+ old-style ref-based variable is created. When eager execution is enabled
+ a resource variable is always created.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
Raises:
ValueError: If both `variable_def` and initial_value are specified.
ValueError: If the initial value is not specified, or does not have a
shape and `validate_shape` is `True`.
RuntimeError: If eager execution is enabled.
-
- @compatibility(eager)
- `tf.Variable` is not compatible with eager execution. Use
- `tfe.Variable` instead which is compatible with both eager execution
- and graph construction. See [the TensorFlow Eager Execution
- guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers)
- for details on how variables work in eager execution.
- @end_compatibility
"""
raise NotImplementedError
@@ -1714,7 +1747,7 @@ class PartitionedVariable(object):
"""A container for partitioned `Variable` objects.
@compatibility(eager) `tf.PartitionedVariable` is not compatible with
- eager execution. Use `tfe.Variable` instead which is compatible
+ eager execution. Use `tf.Variable` instead which is compatible
with both eager execution and graph construction. See [the
TensorFlow Eager Execution
guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers)
diff --git a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
index 23b552cc38..e841c4ad89 100644
--- a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
@@ -49,7 +49,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "assign"
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index 67456a5bdf..c242ef3fdd 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -419,7 +419,7 @@ class SNLIClassifierTrainer(tfe.Checkpointable):
# Create a custom learning rate Variable for the RMSProp optimizer, because
# the learning rate needs to be manually decayed later (see
# decay_learning_rate()).
- self._learning_rate = tfe.Variable(lr, name="learning_rate")
+ self._learning_rate = tf.Variable(lr, name="learning_rate")
self._optimizer = tf.train.RMSPropOptimizer(self._learning_rate,
epsilon=1e-6)