aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-04-26 11:24:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-26 11:29:11 -0700
commit6b6976e3ba19484f893092712e4577daeb92ad3b (patch)
treec939245e41f3e03a8afbbf84c1ff60cbeb95abfd /tensorflow
parenta0af3551a83ba81ddfd2b43cca75edff4c0fcdc1 (diff)
Deprecate tfe.Network and associated utilities in favor of tf.keras.Model.
Also throws an error rather than silently saving incorrectly with tf.train.Checkpoint. (In response to confusion over tf.train.Checkpoint with tfe.Network) PiperOrigin-RevId: 194426679
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/eager/python/network.py49
-rw-r--r--tensorflow/contrib/eager/python/network_test.py7
2 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py
index 2f8721324f..44828bea50 100644
--- a/tensorflow/contrib/eager/python/network.py
+++ b/tensorflow/contrib/eager/python/network.py
@@ -28,9 +28,11 @@ from tensorflow.python.framework import ops
from tensorflow.python.keras._impl.keras.engine import base_layer as keras_base_layer
from tensorflow.python.layers import base
from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
+from tensorflow.python.util import deprecation
# pylint: disable=protected-access
# Explanation for protected-access disable: Network has lots of same-class and
@@ -52,9 +54,40 @@ def _network_name_scope_naming(current_variable_scope):
return current_variable_scope.name + "/"
+_NETWORK_DEPRECATION_MESSAGE = (
+ "Please inherit from `tf.keras.Model`, and see its documentation for "
+ "details. `tf.keras.Model` should be a drop-in replacement for "
+ "`tfe.Network` in most cases, but note that `track_layer` is no longer "
+ "necessary or supported. Instead, `Layer` instances are tracked on "
+ "attribute assignment (see the section of `tf.keras.Model`'s documentation "
+ "on subclassing). Since the output of `track_layer` is often assigned to "
+ "an attribute anyway, most code can be ported by simply removing the "
+ "`track_layer` calls.\n\n`tf.keras.Model` works with all TensorFlow "
+ "`Layer` instances, including those from `tf.layers`, but switching to "
+ "the `tf.keras.layers` versions along with the migration to "
+ "`tf.keras.Model` is recommended, since it will preserve variable names. "
+ "Feel free to import it with an alias to avoid excess typing :)."
+)
+
+
class Network(base.Layer):
"""Represents the composition of a set of Layers.
+ *Deprecated*. Please inherit from `tf.keras.Model`, and see its documentation
+ for details. `tf.keras.Model` should be a drop-in replacement for
+ `tfe.Network` in most cases, but note that `track_layer` is no longer
+ necessary or supported. Instead, `Layer` instances are tracked on attribute
+ assignment (see the section of `tf.keras.Model`'s documentation on
+ subclassing). Since the output of `track_layer` is often assigned to an
+ attribute anyway, most code can be ported by simply removing the `track_layer`
+ calls.
+
+ `tf.keras.Model` works with all TensorFlow `Layer` instances, including those
+ from `tf.layers`, but switching to the `tf.keras.layers` versions along with
+ the migration to `tf.keras.Model` is recommended, since it will preserve
+ variable names. Feel free to import it with an alias to avoid excess typing
+ :).
+
`Network` implements the `Layer` interface and adds convenience methods for
managing sub-`Layer`s, such as listing variables.
@@ -112,6 +145,7 @@ class Network(base.Layer):
# - Detect layers used in __call__ that weren't registered with track_layer.
# - Convert inputs to __call__ to tensors.
+ @deprecation.deprecated(date=None, instructions=_NETWORK_DEPRECATION_MESSAGE)
def __init__(self, name=None):
"""Configure the `Network`.
@@ -130,6 +164,10 @@ class Network(base.Layer):
ValueError: If `name` is not valid. Note that some naming errors will
instead be raised when the `Network` is called.
"""
+ if context.executing_eagerly():
+ logging.warning(
+ ("** tfe.Network is deprecated and will be removed in a future "
+ "version.\n\n%s") % _NETWORK_DEPRECATION_MESSAGE)
if isinstance(name, variable_scope.VariableScope):
raise ValueError("VariableScopes are not valid Network names.")
if name is not None and "/" in name:
@@ -152,6 +190,11 @@ class Network(base.Layer):
self._variable_scope_counts_on_init = (
variable_scope.get_variable_scope_store().variable_scopes_count)
+ def _gather_saveables_for_checkpoint(self):
+ raise NotImplementedError(
+ "tfe.Network does not support object-based checkpointing.\n\n%s"
+ % _NETWORK_DEPRECATION_MESSAGE)
+
def _name_scope_name(self, current_variable_scope):
"""Overrides Layer op naming to match variable naming."""
return _network_name_scope_naming(
@@ -706,6 +749,9 @@ def _make_prefix_stripping_map_fn(scope_name):
return _strip_variable_prefix
+@deprecation.deprecated(date=None, instructions=(
+ "Please inherit from tf.keras.Model instead of tfe.Network, and use "
+ "tf.keras.Model.save_weights."))
def save_network_checkpoint(
network, save_path, global_step=None, map_func=None):
"""Save variables from the Network to a checkpoint.
@@ -905,6 +951,9 @@ def _set_restore_on_create(network, save_path, map_func, user_map_func,
_add_deferred_restoration(network, deferred_restoration)
+@deprecation.deprecated(date=None, instructions=(
+ "Please inherit from tf.keras.Model instead of tfe.Network, and use "
+ "tf.keras.Model.load_weights."))
def restore_network_checkpoint(network, save_path, map_func=None):
"""Restore the Network from a checkpoint.
diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py
index f43376d5d7..6a51d03de5 100644
--- a/tensorflow/contrib/eager/python/network_test.py
+++ b/tensorflow/contrib/eager/python/network_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import checkpointable_utils
from tensorflow.python.training import training_util
@@ -62,6 +63,12 @@ class RegularizedNetwork(network.Network):
class NetworkTest(test.TestCase):
+ def test_checkpointing_not_implemented(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint = checkpointable_utils.Checkpoint(net=MyNetwork())
+ with self.assertRaises(NotImplementedError):
+ checkpoint.save(checkpoint_directory)
+
def _save_modify_load_network_built(self, net, global_step=None):
checkpoint_directory = self.get_temp_dir()
checkpoint_path = network.save_network_checkpoint(