author     Alexandre Passos <apassos@google.com>  2018-07-18 10:16:16 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-07-18 10:22:51 -0700
commit     9cc29a75ce8131db67b48e92dac3c16a255b92ed
tree       73bf7a7483d8f7ae3872437609b6943218938ff4
parent     491b2d61156333c44e6bf06e2ac0a7ac02c4d310
Allows constructing resource variables from tf.Variable.
Also adds arguments to the tf.Variable constructor to control distributed synchronization and aggregation.
Removes tfe.Variable from examples as it's now unnecessary.
PiperOrigin-RevId: 205096552
17 files changed, 106 insertions(+), 87 deletions(-)
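In practice, this change makes the following work. The snippet is an illustrative sketch, not part of the commit, written against the TF 1.x-era API at this revision; it assumes `tf.enable_eager_execution`, the newly exported `tf.VariableSynchronization`/`tf.VariableAggregation` enums, and the new `use_resource` constructor argument shown in the hunks below:

```python
import tensorflow as tf

tf.enable_eager_execution()

# Before this change, tf.Variable raised "Variable not supported when eager
# execution is enabled" and examples had to use tfe.Variable instead. Now the
# constructor forwards to ResourceVariable whenever eager execution is enabled.
v = tf.Variable(1.0)
assert v.numpy() == 1.0

# The constructor also gains distribution-related arguments: use_resource opts
# into a resource variable in graph mode, while synchronization/aggregation
# tell a DistributionStrategy when to sync the variable and how to combine
# updates across devices. Shown here with their default values.
w = tf.Variable(
    initial_value=tf.zeros([4]),
    use_resource=True,
    synchronization=tf.VariableSynchronization.AUTO,
    aggregation=tf.VariableAggregation.NONE)
```

This is also why the diff deletes `testVariableError` from tfe_test.py and swaps `tfe.Variable` for `tf.Variable` throughout the examples.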
diff --git a/tensorflow/contrib/checkpoint/python/containers.py b/tensorflow/contrib/checkpoint/python/containers.py
index 4d3d531299..242c1e8ba4 100644
--- a/tensorflow/contrib/checkpoint/python/containers.py
+++ b/tensorflow/contrib/checkpoint/python/containers.py
@@ -35,9 +35,9 @@ class UniqueNameTracker(data_structures.CheckpointableDataStructure):
       self.slotdeps = tf.contrib.checkpoint.UniqueNameTracker()
       slotdeps = self.slotdeps
       slots = []
-      slots.append(slotdeps.track(tfe.Variable(3.), "x"))  # Named "x"
-      slots.append(slotdeps.track(tfe.Variable(4.), "y"))
-      slots.append(slotdeps.track(tfe.Variable(5.), "x"))  # Named "x_1"
+      slots.append(slotdeps.track(tf.Variable(3.), "x"))  # Named "x"
+      slots.append(slotdeps.track(tf.Variable(4.), "y"))
+      slots.append(slotdeps.track(tf.Variable(5.), "x"))  # Named "x_1"
   ```
   """
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
index 729d8525fa..275aee5130 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/l2hmc.py
@@ -54,7 +54,7 @@ class Dynamics(tf.keras.Model):
     self.position_fn = neural_nets.GenericNet(x_dim, factor=2.)
     self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.)
 
-    self.eps = tfe.Variable(
+    self.eps = tf.Variable(
         initial_value=eps, name="eps", dtype=tf.float32, trainable=True)
 
   def apply_transition(self, position):
diff --git a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
index e230ad5e25..68e0bc3123 100644
--- a/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
+++ b/tensorflow/contrib/eager/python/examples/l2hmc/neural_nets.py
@@ -25,7 +25,6 @@ from __future__ import division
 from __future__ import print_function
 
 import tensorflow as tf
-import tensorflow.contrib.eager as tfe
 
 
 class GenericNet(tf.keras.Model):
@@ -47,13 +46,13 @@ class GenericNet(tf.keras.Model):
 
     # Scale
     self.scale_layer = _custom_dense(x_dim, .001)
-    self.coeff_scale = tfe.Variable(
+    self.coeff_scale = tf.Variable(
         initial_value=tf.zeros([1, x_dim]), name='coeff_scale', trainable=True)
     # Translation
     self.translation_layer = _custom_dense(x_dim, factor=.001)
     # Transformation
     self.transformation_layer = _custom_dense(x_dim, .001)
-    self.coeff_transformation = tfe.Variable(
+    self.coeff_transformation = tf.Variable(
         initial_value=tf.zeros([1, x_dim]),
         name='coeff_transformation',
         trainable=True)
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
index 591e2d0c85..5f1b48fa0d 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
@@ -118,7 +118,6 @@
       "cell_type": "code",
       "source": [
         "import tensorflow as tf\n",
-        "tfe = tf.contrib.eager # Shorthand for some symbols\n",
         "\n",
         "tf.enable_eager_execution()"
       ],
@@ -184,7 +183,7 @@
       },
       "cell_type": "code",
       "source": [
-        "v = tfe.Variable(1.0)\n",
+        "v = tf.Variable(1.0)\n",
         "assert v.numpy() == 1.0\n",
         "\n",
         "# Re-assign the value\n",
@@ -258,8 +257,8 @@
         "  def __init__(self):\n",
         "    # Initialize variable to (5.0, 0.0)\n",
         "    # In practice, these should be initialized to random values.\n",
-        "    self.W = tfe.Variable(5.0)\n",
-        "    self.b = tfe.Variable(0.0)\n",
+        "    self.W = tf.Variable(5.0)\n",
+        "    self.b = tf.Variable(0.0)\n",
         "  \n",
         "  def __call__(self, x):\n",
         "    return self.W * x + self.b\n",
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
index b2ac4b67c9..b0d0a5486d 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -138,7 +138,7 @@ class RevNetTest(tf.test.TestCase):
         minval=0,
         maxval=self.config.n_classes,
         dtype=tf.int32)
-    global_step = tfe.Variable(0., trainable=False)
+    global_step = tf.Variable(0., trainable=False)
     model = revnet.RevNet(config=config)
     model(x)
     updates = model.get_updates_for(x)
diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
index c2340a293a..d64bf5354e 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py
@@ -310,7 +310,7 @@ def main(_):
   with tf.device("/device:GPU:0" if have_gpu else None):
     # Make learning_rate a Variable so it can be included in the checkpoint
     # and we can resume training with the last saved learning_rate.
-    learning_rate = tfe.Variable(20.0, name="learning_rate")
+    learning_rate = tf.Variable(20.0, name="learning_rate")
     model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim,
                      FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout,
                      use_cudnn_rnn)
diff --git a/tensorflow/contrib/eager/python/examples/sagan/sagan.py b/tensorflow/contrib/eager/python/examples/sagan/sagan.py
index 561be36c91..8130414985 100644
--- a/tensorflow/contrib/eager/python/examples/sagan/sagan.py
+++ b/tensorflow/contrib/eager/python/examples/sagan/sagan.py
@@ -62,7 +62,7 @@ class SelfAttentionModule(tf.keras.Model):
         kernel_size=1,
         strides=(1, 1),
         data_format=data_format)
-    self.scale = tfe.Variable(0., trainable=True)
+    self.scale = tf.Variable(0., trainable=True)
 
   def call(self, x):
     f = self.f(x)
diff --git a/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb
index 4f1410e00b..f3a65f5aab 100644
--- a/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb
+++ b/tensorflow/contrib/eager/python/examples/workshop/2_models.ipynb
@@ -69,7 +69,7 @@
       "cell_type": "code",
       "source": [
         "# Creating variables\n",
-        "v = tfe.Variable(1.0)\n",
+        "v = tf.Variable(1.0)\n",
         "v"
       ],
       "execution_count": 2,
diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py
index db50b33af2..4454abfb96 100644
--- a/tensorflow/contrib/eager/python/tfe_test.py
+++ b/tensorflow/contrib/eager/python/tfe_test.py
@@ -27,7 +27,6 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import numerics
-from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.summary import summary
 from tensorflow.python.summary.writer import writer
@@ -45,12 +44,6 @@ class TFETest(test_util.TensorFlowTestCase):
                                  r'indices = 7 is not in \[0, 3\)'):
       array_ops.gather([0, 1, 2], 7)
 
-  def testVariableError(self):
-    with self.assertRaisesRegexp(
-        RuntimeError,
-        r'Variable not supported when eager execution is enabled'):
-      variables.Variable(initial_value=1.0)
-
   def testGradients(self):
 
     def square(x):
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt
index b07ee9fda9..17b79ee30c 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt
@@ -51,7 +51,7 @@ For example, say we want to update 4 scattered elements to a rank-1 tensor to
 8 elements. In Python, that update would look like this:
 
 ```python
-    ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
     indices = tf.constant([[4], [3], [1] ,[7]])
     updates = tf.constant([9, 10, 11, 12])
     update = tf.scatter_nd_update(ref, indices, updates)
diff --git a/tensorflow/docs_src/guide/eager.md b/tensorflow/docs_src/guide/eager.md
index e98206eef9..42ad9652f8 100644
--- a/tensorflow/docs_src/guide/eager.md
+++ b/tensorflow/docs_src/guide/eager.md
@@ -225,7 +225,7 @@ the tape backwards and then discard. A particular `tf.GradientTape` can only
 compute one gradient; subsequent calls throw a runtime error.
 
 ```py
-w = tfe.Variable([[1.0]])
+w = tf.Variable([[1.0]])
 with tf.GradientTape() as tape:
   loss = w * w
 
@@ -260,8 +260,8 @@ def grad(weights, biases):
 train_steps = 200
 learning_rate = 0.01
 # Start with arbitrary values for W and B on the same batch of data
-W = tfe.Variable(5.)
-B = tfe.Variable(10.)
+W = tf.Variable(5.)
+B = tf.Variable(10.)
 
 print("Initial loss: {:.3f}".format(loss(W, B)))
 
@@ -407,11 +407,11 @@ with tf.device("/gpu:0"):
 
 ### Variables and optimizers
 
-`tfe.Variable` objects store mutable `tf.Tensor` values accessed during
+`tf.Variable` objects store mutable `tf.Tensor` values accessed during
 training to make automatic differentiation easier. The parameters of a model
 can be encapsulated in classes as variables.
 
-Better encapsulate model parameters by using `tfe.Variable` with
+Better encapsulate model parameters by using `tf.Variable` with
 `tf.GradientTape`. For example, the automatic differentiation example above
 can be rewritten:
 
@@ -419,8 +419,8 @@ can be rewritten:
 class Model(tf.keras.Model):
   def __init__(self):
     super(Model, self).__init__()
-    self.W = tfe.Variable(5., name='weight')
-    self.B = tfe.Variable(10., name='bias')
+    self.W = tf.Variable(5., name='weight')
+    self.B = tf.Variable(10., name='bias')
 
   def call(self, inputs):
     return inputs * self.W + self.B
@@ -498,17 +498,17 @@ is removed, and is then deleted.
 
 ```py
 with tf.device("gpu:0"):
-  v = tfe.Variable(tf.random_normal([1000, 1000]))
+  v = tf.Variable(tf.random_normal([1000, 1000]))
   v = None  # v no longer takes up GPU memory
 ```
 
 ### Object-based saving
 
-`tfe.Checkpoint` can save and restore `tfe.Variable`s to and from
+`tfe.Checkpoint` can save and restore `tf.Variable`s to and from
 checkpoints:
 
 ```py
-x = tfe.Variable(10.)
+x = tf.Variable(10.)
 checkpoint = tfe.Checkpoint(x=x)  # save as "x"
 
@@ -612,7 +612,7 @@ def line_search_step(fn, init_x, rate=1.0):
 `tf.GradientTape` is a powerful interface for computing gradients, but there
 is another [Autograd](https://github.com/HIPS/autograd)-style API available for
 automatic differentiation. These functions are useful if writing math code with
-only tensors and gradient functions, and without `tfe.Variables`:
+only tensors and gradient functions, and without `tf.Variables`:
 
 *   `tfe.gradients_function` —Returns a function that computes the derivatives
     of its input function parameter with respect to its arguments. The input
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index e358293a90..c739cd2c0d 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -246,6 +246,15 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
     read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
     self.assertEqual(self.evaluate(read), [[2]])
 
+  def testUseResource(self):
+    v = variables.Variable(1.0, use_resource=True)
+    self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable))
+
+  def testEagerNoUseResource(self):
+    with context.eager_mode():
+      v = variables.Variable(1.0)
+      self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable))
+
   @test_util.run_in_graph_and_eager_modes
   def testScatterMin(self):
     with ops.device("cpu:0"):
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 1f56ad25bf..5979b76ff2 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -1294,3 +1294,16 @@ def is_resource_variable(var):
   """"Returns True if `var` is to be considered a ResourceVariable."""
   return isinstance(var, ResourceVariable) or hasattr(
       var, "_should_act_as_resource_variable")
+
+
+_DEFAULT_USE_RESOURCE = False
+
+
+def _default_variable_creator(_, *args, **kwds):
+  use_resource = kwds.pop("use_resource", _DEFAULT_USE_RESOURCE)
+  use_resource = use_resource or context.executing_eagerly()
+  if use_resource:
+    return ResourceVariable(*args, **kwds)
+  return variables.RefVariable(*args, **kwds)
+
+variables.default_variable_creator = _default_variable_creator
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 77f67c18ee..0f37dcc027 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -191,36 +191,9 @@ class _ReuseMode(enum.Enum):
   # REUSE_TRUE = 3
 
 
-@tf_export("VariableSynchronization")
-class VariableSynchronization(enum.Enum):
-  """Indicates when a distributed variable will be synced."""
-
-  # Indicates that the synchronization will be determined by the current
-  # `DistributionStrategy` (eg. With `MirroredStrategy` this would be
-  # `ON_WRITE`).
-  AUTO = 0
-
-  # Indicates that there will only be one copy of the variable, so there is no
-  # need to sync.
-  NONE = 1
-
-  # Indicates that the variable will be aggregated across devices
-  # every time it is updated.
-  ON_WRITE = 2
-
-  # Indicates that the variable will be aggregated across devices
-  # when it is read (eg. when checkpointing or when evaluating an op that uses
-  # the variable).
-  ON_READ = 3
-
-
-@tf_export("VariableAggregation")
-class VariableAggregation(enum.Enum):
-  """Indicates how a distributed variable will be aggregated."""
-  NONE = 0
-  SUM = 1
-  MEAN = 2
-
+# TODO(apassos) remove these forwarding symbols.
+VariableSynchronization = variables.VariableSynchronization  # pylint: disable=invalid-name
+VariableAggregation = variables.VariableAggregation  # pylint: disable=invalid-name
 
 AUTO_REUSE = _ReuseMode.AUTO_REUSE
 tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE")
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 87e0de197c..6bb2d6f669 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import enum  # pylint: disable=g-bad-import-order
+
 import six
 
 from tensorflow.core.framework import attr_value_pb2
@@ -38,8 +40,9 @@ from tensorflow.python.util.deprecation import deprecated
 from tensorflow.python.util.tf_export import tf_export
 
 
-def _default_variable_creator(_, *args, **kwds):
-  return RefVariable(*args, **kwds)
+def default_variable_creator(_, *args, **kwds):
+  del args, kwds
+  raise NotImplementedError("resource_variable_ops needs to be imported")
 
 
 def _make_getter(captured_getter, captured_previous):
@@ -49,12 +52,43 @@ def _make_getter(captured_getter, captured_previous):
   return getter
 
 
+@tf_export("VariableSynchronization")
+class VariableSynchronization(enum.Enum):
+  """Indicates when a distributed variable will be synced."""
+
+  # Indicates that the synchronization will be determined by the current
+  # `DistributionStrategy` (eg. With `MirroredStrategy` this would be
+  # `ON_WRITE`).
+  AUTO = 0
+
+  # Indicates that there will only be one copy of the variable, so there is no
+  # need to sync.
+  NONE = 1
+
+  # Indicates that the variable will be aggregated across devices
+  # every time it is updated.
+  ON_WRITE = 2
+
+  # Indicates that the variable will be aggregated across devices
+  # when it is read (eg. when checkpointing or when evaluating an op that uses
+  # the variable).
+  ON_READ = 3
+
+
+@tf_export("VariableAggregation")
+class VariableAggregation(enum.Enum):
+  """Indicates how a distributed variable will be aggregated."""
+  NONE = 0
+  SUM = 1
+  MEAN = 2
+
+
 class VariableMetaclass(type):
   """Metaclass to allow construction of tf.Variable to be overridden."""
 
   def __call__(cls, *args, **kwargs):
     if cls is Variable:
-      previous_getter = lambda *a, **k: _default_variable_creator(None, *a, **k)
+      previous_getter = lambda *a, **k: default_variable_creator(None, *a, **k)
       # TODO(apassos) use a stack of getters here
       return previous_getter(*args, **kwargs)
     else:
@@ -172,14 +206,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
   * Replace `tf.Variable` with `tf.contrib.eager.Variable`;
   * Call `tf.get_variable_scope().set_use_resource(True)` inside a
     `tf.variable_scope` before the `tf.get_variable()` call.
-
-  @compatibility(eager)
-  `tf.Variable` is not compatible with eager execution. Use
-  `tf.contrib.eager.Variable` instead which is compatible with both eager
-  execution and graph construction. See [the TensorFlow Eager Execution
-  guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers)
-  for details on how variables work in eager execution.
-  @end_compatibility
   """
 
   def __init__(self,
@@ -193,7 +219,10 @@
                dtype=None,
                expected_shape=None,
                import_scope=None,
-               constraint=None):
+               constraint=None,
+               use_resource=None,
+               synchronization=VariableSynchronization.AUTO,
+               aggregation=VariableAggregation.NONE):
     """Creates a new variable with value `initial_value`.
 
     The new variable is added to the graph collections listed in `collections`,
@@ -245,20 +274,24 @@
         variable and return the Tensor for the projected value
         (which must have the same shape). Constraints are not safe to
         use when doing asynchronous distributed training.
+      use_resource: if True, a ResourceVariable is created; otherwise an
+        old-style ref-based variable is created. When eager execution is
+        enabled a resource variable is always created.
+      synchronization: Indicates when a distributed variable will be
+        aggregated. Accepted values are constants defined in the class
+        @{tf.VariableSynchronization}. By default the synchronization is set
+        to `AUTO` and the current `DistributionStrategy` chooses when to
+        synchronize. If `synchronization` is set to `ON_READ`, `trainable`
+        must not be set to `True`.
+      aggregation: Indicates how a distributed variable will be aggregated.
+        Accepted values are constants defined in the class
+        @{tf.VariableAggregation}.
 
     Raises:
       ValueError: If both `variable_def` and initial_value are specified.
       ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
      RuntimeError: If eager execution is enabled.
-
-    @compatibility(eager)
-    `tf.Variable` is not compatible with eager execution. Use
-    `tfe.Variable` instead which is compatible with both eager execution
-    and graph construction. See [the TensorFlow Eager Execution
-    guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers)
-    for details on how variables work in eager execution.
-    @end_compatibility
     """
     raise NotImplementedError
 
@@ -1714,7 +1747,7 @@ class PartitionedVariable(object):
   """A container for partitioned `Variable` objects.
 
   @compatibility(eager) `tf.PartitionedVariable` is not compatible with
-  eager execution. Use `tfe.Variable` instead which is compatible
+  eager execution. Use `tf.Variable` instead which is compatible
   with both eager execution and graph construction. See [the TensorFlow
   Eager Execution
   guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/g3doc/guide.md#variables-and-optimizers)
diff --git a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
index 23b552cc38..e841c4ad89 100644
--- a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
@@ -49,7 +49,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+    argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
   }
   member_method {
     name: "assign"
diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py
index 67456a5bdf..c242ef3fdd 100644
--- a/third_party/examples/eager/spinn/spinn.py
+++ b/third_party/examples/eager/spinn/spinn.py
@@ -419,7 +419,7 @@ class SNLIClassifierTrainer(tfe.Checkpointable):
     # Create a custom learning rate Variable for the RMSProp optimizer, because
     # the learning rate needs to be manually decayed later (see
    # decay_learning_rate()).
-    self._learning_rate = tfe.Variable(lr, name="learning_rate")
+    self._learning_rate = tf.Variable(lr, name="learning_rate")
     self._optimizer = tf.train.RMSPropOptimizer(self._learning_rate,
                                                 epsilon=1e-6)
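A note on the mechanics, since it is easy to miss across the hunks above: `variables.py` cannot import `resource_variable_ops.py` (which itself imports `variables.py`), so `default_variable_creator` is declared in `variables.py` as a stub that raises `NotImplementedError`, and `resource_variable_ops.py` rebinds it at import time, once both `RefVariable` and `ResourceVariable` are visible. Below is a minimal standalone sketch of that late-binding pattern; the class and function names are hypothetical stand-ins, the two TensorFlow modules are collapsed into one file, and the leading `_` placeholder argument (reserved for a planned stack of creators) is omitted:

```python
class RefThing(object):
  """Stand-in for the old ref-based RefVariable."""

  def __init__(self, value):
    self.value = value


class ResourceThing(object):
  """Stand-in for ResourceVariable."""

  def __init__(self, value):
    self.value = value


def default_creator(*args, **kwargs):
  # The base module ships only this stub; it raises until the module that
  # knows about both implementations rebinds it.
  raise NotImplementedError("concrete variable module not imported yet")


class ThingMeta(type):
  """Like VariableMetaclass: constructing the base class goes via the hook."""

  def __call__(cls, *args, **kwargs):
    if cls is Thing:
      # default_creator is looked up at call time, so the rebinding below
      # takes effect for every later Thing(...) call.
      return default_creator(*args, **kwargs)
    return super(ThingMeta, cls).__call__(*args, **kwargs)


class Thing(metaclass=ThingMeta):
  pass


# What resource_variable_ops.py does on import: replace the stub with a
# creator that picks an implementation per call.
def _creator(*args, **kwargs):
  if kwargs.pop("use_resource", False):
    return ResourceThing(*args, **kwargs)
  return RefThing(*args, **kwargs)


default_creator = _creator  # in TF: variables.default_variable_creator = ...

assert isinstance(Thing(1.0), RefThing)
assert isinstance(Thing(2.0, use_resource=True), ResourceThing)
```

The same indirection is why `variable_scope.py` can drop its own copies of the enums and simply forward `VariableSynchronization`/`VariableAggregation` to `variables.py`: the base module is now the single home for everything the creator hook needs.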