diff options
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/core/framework/variable.proto | 3 | ||||
-rw-r--r-- | tensorflow/python/eager/function.py | 2 | ||||
-rw-r--r-- | tensorflow/python/eager/graph_callable.py | 2 | ||||
-rw-r--r-- | tensorflow/python/eager/pywrap_tfe_src.cc | 4 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/network.py | 52 | ||||
-rw-r--r-- | tensorflow/python/keras/model_subclassing_test.py | 45 | ||||
-rw-r--r-- | tensorflow/python/keras/utils/layer_utils.py | 55 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/resource_variable_ops_test.py | 19 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/variables_test.py | 17 | ||||
-rw-r--r-- | tensorflow/python/ops/resource_variable_ops.py | 8 | ||||
-rw-r--r-- | tensorflow/python/ops/variable_scope.py | 6 | ||||
-rw-r--r-- | tensorflow/python/ops/variables.py | 7 | ||||
-rw-r--r-- | tensorflow/python/training/checkpointable/data_structures.py | 36 | ||||
-rw-r--r-- | tensorflow/python/training/checkpointable/data_structures_test.py | 19 | ||||
-rw-r--r-- | tensorflow/tools/api/golden/tensorflow.-variable.pbtxt | 4 |
15 files changed, 233 insertions, 46 deletions
diff --git a/tensorflow/core/framework/variable.proto b/tensorflow/core/framework/variable.proto index 93ae423bab..66ba4cba7d 100644 --- a/tensorflow/core/framework/variable.proto +++ b/tensorflow/core/framework/variable.proto @@ -26,6 +26,9 @@ message VariableDef { // Whether to represent this as a ResourceVariable. bool is_resource = 5; + + // Whether this variable should be trained. + bool trainable = 7; } message SaveSliceInfoDef { diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 23d87fb394..559063d6ae 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -494,7 +494,7 @@ class GraphModeFunction(object): def __call__(self, *args): """Executes the passed function in eager mode.""" for v in self._variables: - if v._trainable: # pylint: disable=protected-access + if v.trainable: tape.watch_variable(v) tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)] diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py index d9ffcbd203..760a148552 100644 --- a/tensorflow/python/eager/graph_callable.py +++ b/tensorflow/python/eager/graph_callable.py @@ -202,7 +202,7 @@ class _InitializingFunctionObject(object): v.handle).numpy() for v in self._call_fn.variables] if all(x for x in initialized): for v in self._call_fn.variables: - if v._trainable: # pylint: disable=protected-access + if v.trainable: tape.watch_variable(v) return self._call_fn(*args) elif all(not x for x in initialized): diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 52b90504f3..e3ce0ef9d0 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -1874,10 +1874,10 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, void MaybeWatchVariable(PyObject* input) { DCHECK(CheckResourceVariable(input)); - DCHECK(PyObject_HasAttrString(input, "_trainable")); + DCHECK(PyObject_HasAttrString(input, "trainable")); tensorflow::Safe_PyObjectPtr trainable( - PyObject_GetAttrString(input, "_trainable")); + PyObject_GetAttrString(input, "trainable")); if (trainable.get() == Py_False) return; TFE_Py_TapeSetWatchVariable(input); } diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 6db41472b6..f63ca1a207 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -36,9 +36,10 @@ from tensorflow.python.keras import backend from tensorflow.python.keras.engine import base_layer from tensorflow.python.keras.engine import saving from tensorflow.python.keras.utils import generic_utils +from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite -from tensorflow.python.keras.utils.layer_utils import print_summary as print_layer_summary +from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.training.checkpointable import data_structures_base @@ -94,6 +95,11 @@ class Network(base_layer.Layer): self.trainable = True self._is_compiled = False self._expects_training_arg = False + # A list of "extra" variables assigned to attributes of this class, included + # in self.weights and self.variables. Always empty for graph networks (but + # included in base_init to avoid excessive special casing when retrieving + # the value). + self._extra_variables = [] self.supports_masking = False if not hasattr(self, 'optimizer'): @@ -347,11 +353,22 @@ class Network(base_layer.Layer): # layers). Therefore Model tracks Checkpointable objects itself. self._track_checkpointable( checkpointable=value, name=name, overwrite=True) + if ( # For subclassed models only, users may add extra weights/variables + # simply by assigning them to attributes. + not self._is_graph_network + and isinstance(value, variables.Variable)): + self._extra_variables.append(value) super(Network, self).__setattr__(name, value) def add_variable(self, name, shape, dtype=None, initializer=None, regularizer=None, trainable=True, constraint=None): - raise NotImplementedError('`add_variable` is not supported on Networks.') + if self._is_graph_network: + raise NotImplementedError('`add_variable` is not supported on Networks.') + else: + raise NotImplementedError( + '`add_variable` is not supported on Networks. However, you may ' + 'assign variables to attributes and they will show up in the weights ' + 'and variables properties.') def add_loss(self, *args, **kwargs): if context.executing_eagerly(): @@ -589,24 +606,17 @@ class Network(base_layer.Layer): @property def trainable_weights(self): - if not self.trainable: - return [] - weights = [] - for layer in self.layers: - weights += layer.trainable_weights - return weights + return layer_utils.gather_trainable_weights( + trainable=self.trainable, + sub_layers=self.layers, + extra_variables=self._extra_variables) @property def non_trainable_weights(self): - weights = [] - for layer in self.layers: - weights += layer.non_trainable_weights - if not self.trainable: - trainable_weights = [] - for layer in self.layers: - trainable_weights += layer.trainable_weights - return trainable_weights + weights - return weights + return layer_utils.gather_non_trainable_weights( + trainable=self.trainable, + sub_layers=self.layers, + extra_variables=self._extra_variables) @property def input_spec(self): @@ -1437,10 +1447,10 @@ class Network(base_layer.Layer): 'have not yet been created, so no summary can be ' 'displayed. Build the model first ' '(e.g. by calling it on some data).') - print_layer_summary(self, - line_length=line_length, - positions=positions, - print_fn=print_fn) + layer_utils.print_summary(self, + line_length=line_length, + positions=positions, + print_fn=print_fn) def get_source_inputs(tensor, layer=None, node_index=None): diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py index 558854ab97..86f7e20bec 100644 --- a/tensorflow/python/keras/model_subclassing_test.py +++ b/tensorflow/python/keras/model_subclassing_test.py @@ -622,6 +622,51 @@ class ModelSubclassingTest(test.TestCase): self.assertIs(m.isdep, m._checkpoint_dependencies[0].ref) self.assertEqual('notdep_var:0', m.notdep_var.name) + def test_extra_variable(self): + + class ExtraVar(keras.Model): + + def __init__(self): + super(ExtraVar, self).__init__() + self.dense = keras.layers.Dense(1) + self.var = resource_variable_ops.ResourceVariable(1.) + self.not_trainable_var = resource_variable_ops.ResourceVariable( + 2., trainable=False) + + def call(self, inputs): + return self.dense(inputs + self.var) + + m = ExtraVar() + self.assertTrue(m.trainable) + self.assertEqual([m.dense], m.layers) + self.assertEqual([m.var, m.not_trainable_var], m.variables) + self.assertEqual([m.var], m.trainable_variables) + self.assertEqual([m.not_trainable_var], m.non_trainable_variables) + m.trainable = False + self.assertEqual([m.var, m.not_trainable_var], m.variables) + self.assertEqual([], m.trainable_variables) + self.assertEqual([m.var, m.not_trainable_var], m.non_trainable_variables) + m.trainable = True + + m(array_ops.ones([1, 1])) + + self.assertEqual([m.dense.kernel, m.dense.bias], m.dense.variables) + self.assertEqual([m.dense.kernel, m.dense.bias], m.dense.weights) + + self.assertEqual([m.dense.kernel, m.dense.bias, m.var, m.not_trainable_var], + m.variables) + self.assertEqual([m.dense.kernel, m.dense.bias, m.var], + m.trainable_variables) + self.assertEqual([m.not_trainable_var], m.non_trainable_variables) + + m.dense.trainable = False + self.assertEqual( + [m.var, m.dense.kernel, m.dense.bias, m.not_trainable_var], + m.variables) + self.assertEqual([m.var], m.trainable_variables) + self.assertEqual([m.dense.kernel, m.dense.bias, m.not_trainable_var], + m.non_trainable_variables) + class CustomCallModel(keras.Model): diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py index bd61f8e9cc..88daff0461 100644 --- a/tensorflow/python/keras/utils/layer_utils.py +++ b/tensorflow/python/keras/utils/layer_utils.py @@ -201,6 +201,61 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): print_fn('_' * line_length) +def gather_trainable_weights(trainable, sub_layers, extra_variables): + """Lists the trainable weights for an object with sub-layers. + + Args: + trainable: Whether the object collecting the variables is trainable. + sub_layers: A flat list of Layer objects owned by this object, to collect + variables from. + extra_variables: Any extra variables to include. Their `.trainable` property + is used to categorize them. + + Returns: + A list of collected trainable weights/variables. + """ + if not trainable: + return [] + weights = [] + for layer in sub_layers: + weights += layer.trainable_weights + trainable_extra_variables = [ + v for v in extra_variables if v.trainable] + return weights + trainable_extra_variables + + +def gather_non_trainable_weights(trainable, sub_layers, extra_variables): + """Lists the non-trainable weights for an object with sub-layers. + + Args: + trainable: Whether the object collecting the variables is trainable. + sub_layers: A flat list of Layer objects owned by this object, to collect + variables from. + extra_variables: Any extra variables to include. Their `.trainable` property + is used to categorize them. + + Returns: + A list of collected non-trainable weights/variables. + """ + trainable_extra_variables = [] + non_trainable_extra_variables = [] + for v in extra_variables: + if v.trainable: + trainable_extra_variables.append(v) + else: + non_trainable_extra_variables.append(v) + weights = [] + for layer in sub_layers: + weights += layer.non_trainable_weights + if not trainable: + trainable_weights = [] + for layer in sub_layers: + trainable_weights += layer.trainable_weights + return (trainable_weights + trainable_extra_variables + + weights + non_trainable_extra_variables) + return weights + non_trainable_extra_variables + + @tf_export('keras.utils.convert_all_kernels_in_model') def convert_all_kernels_in_model(model): """Converts all convolution kernels in a model from Theano to TensorFlow. diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 972fbdb3d6..00d517e64e 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -538,6 +538,25 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): sess.run(v.initialized_value()) + def testTrainableInProto(self): + with ops.Graph().as_default(): + non_trainable_variable = resource_variable_ops.ResourceVariable( + trainable=False, + initial_value=constant_op.constant(10.0)) + self.assertEqual( + False, + resource_variable_ops.ResourceVariable( + variable_def=non_trainable_variable.to_proto()) + .trainable) + trainable_variable = resource_variable_ops.ResourceVariable( + trainable=True, + initial_value=constant_op.constant(10.0)) + self.assertEqual( + True, + resource_variable_ops.ResourceVariable( + variable_def=trainable_variable.to_proto()) + .trainable) + @test_util.run_in_graph_and_eager_modes() def testSparseRead(self): with self.test_session(): diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py index 27599868b7..62d596da91 100644 --- a/tensorflow/python/kernel_tests/variables_test.py +++ b/tensorflow/python/kernel_tests/variables_test.py @@ -496,6 +496,23 @@ class VariablesTestCase(test.TestCase): with self.assertRaises(ValueError): sess.run(v.initialized_value()) + def testTrainableInProto(self): + with ops.Graph().as_default(): + non_trainable_variable = variables.Variable( + trainable=False, + initial_value=constant_op.constant(10.0)) + self.assertEqual( + False, + variables.Variable(variable_def=non_trainable_variable.to_proto()) + .trainable) + trainable_variable = variables.Variable( + trainable=True, + initial_value=constant_op.constant(10.0)) + self.assertEqual( + True, + variables.Variable(variable_def=trainable_variable.to_proto()) + .trainable) + def testLoad(self): with self.test_session(): var = variables.Variable(np.zeros((5, 5), np.float32)) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index e37e93ea35..7061b32808 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -551,6 +551,7 @@ class ResourceVariable(variables.Variable): import_scope=import_scope)) else: self._initial_value = None + self._trainable = getattr(variable_def, "trainable", True) if variable_def.snapshot_name: snapshot = g.as_graph_element( ops.prepend_name_scope( @@ -735,7 +736,7 @@ class ResourceVariable(variables.Variable): return self._save_slice_info def _read_variable_op(self): - if hasattr(self, "_trainable") and self._trainable: + if self.trainable: tape.watch_variable(self) return gen_resource_variable_ops.read_variable_op(self._handle, self._dtype) @@ -760,7 +761,7 @@ class ResourceVariable(variables.Variable): def sparse_read(self, indices, name=None): """Reads the value of this variable sparsely, using `gather`.""" with ops.name_scope("Gather" if name is None else name) as name: - if self._trainable: + if self.trainable: tape.watch_variable(self) value = gen_resource_variable_ops.resource_gather( self._handle, indices, dtype=self._dtype, name=name) @@ -801,6 +802,7 @@ class ResourceVariable(variables.Variable): var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name, export_scope) var_def.is_resource = True + var_def.trainable = self.trainable if self._save_slice_info: var_def.save_slice_info_def.MergeFrom( self._save_slice_info.to_proto(export_scope=export_scope)) @@ -913,7 +915,7 @@ class ResourceVariable(variables.Variable): return assign_add_op def _lazy_read(self, op): - if hasattr(self, "_trainable") and self._trainable: + if self.trainable: tape.watch_variable(self) return _UnreadVariable( self._handle, self.dtype, self._shape, self._in_graph_mode, diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 8d93d24b14..fa34774622 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -1261,13 +1261,13 @@ class EagerVariableStore(object): def trainable_variables(self): # pylint: disable=protected-access - return sorted([x for x in self._store._vars.values() if x._trainable], + return sorted([x for x in self._store._vars.values() if x.trainable], key=lambda x: x.name) # pylint: enable=protected-access def non_trainable_variables(self): # pylint: disable=protected-access - return sorted([x for x in self._store._vars.values() if not x._trainable], + return sorted([x for x in self._store._vars.values() if not x.trainable], key=lambda x: x.name) # pylint: enable=protected-access @@ -1296,7 +1296,7 @@ class EagerVariableStore(object): new_var = resource_variable_ops.ResourceVariable( var.read_value(), name=stripped_var_name, - trainable=var._trainable) + trainable=var.trainable) new_store._store._vars[key] = new_var return new_store # pylint: enable=protected-access diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index d88fd836f5..4be9f5eb68 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -341,6 +341,7 @@ class Variable(checkpointable.CheckpointableBase): self._update_uid = initial_value.checkpoint_position.restore_uid initial_value = initial_value.wrapped_value + self._trainable = trainable if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] with ops.init_scope(): @@ -450,6 +451,7 @@ class Variable(checkpointable.CheckpointableBase): import_scope=import_scope)) else: self._initial_value = None + self._trainable = getattr(variable_def, "trainable", True) self._snapshot = g.as_graph_element( ops.prepend_name_scope(variable_def.snapshot_name, import_scope=import_scope)) @@ -543,6 +545,10 @@ class Variable(checkpointable.CheckpointableBase): self._ref().set_shape(shape) self.value().set_shape(shape) + @property + def trainable(self): + return self._trainable + def eval(self, session=None): """In a session, computes and returns the value of this variable. @@ -1050,6 +1056,7 @@ class Variable(checkpointable.CheckpointableBase): # For backwards compatibility. var_def.initial_value_name = ops.strip_name_scope( self._initial_value.name, export_scope) + var_def.trainable = self.trainable var_def.initializer_name = ops.strip_name_scope( self.initializer.name, export_scope) var_def.snapshot_name = ops.strip_name_scope( diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py index 62cefa4f20..69ed253fb2 100644 --- a/tensorflow/python/training/checkpointable/data_structures.py +++ b/tensorflow/python/training/checkpointable/data_structures.py @@ -22,6 +22,8 @@ import collections import six from tensorflow.python.keras.engine import base_layer +from tensorflow.python.keras.utils import layer_utils +from tensorflow.python.ops import variables from tensorflow.python.training.checkpointable import base as checkpointable_lib from tensorflow.python.training.checkpointable import data_structures_base @@ -41,11 +43,14 @@ class CheckpointableDataStructure( def __init__(self): self._layers = [] self.trainable = True + self._extra_variables = [] def _track_value(self, value, name): """Add a dependency on `value`.""" if isinstance(value, checkpointable_lib.CheckpointableBase): self._track_checkpointable(value, name=name) + if isinstance(value, variables.Variable): + self._extra_variables.append(value) else: raise ValueError( ("Only checkpointable objects (such as Layers or Optimizers) may be " @@ -67,30 +72,31 @@ class CheckpointableDataStructure( @property def trainable_weights(self): - if not self.trainable: - return [] - weights = [] - for layer in self.layers: - weights += layer.trainable_weights - return weights + return layer_utils.gather_trainable_weights( + trainable=self.trainable, + sub_layers=self.layers, + extra_variables=self._extra_variables) @property def non_trainable_weights(self): - weights = [] - for layer in self.layers: - weights += layer.non_trainable_weights - if not self.trainable: - trainable_weights = [] - for layer in self.layers: - trainable_weights += layer.trainable_weights - return trainable_weights + weights - return weights + return layer_utils.gather_non_trainable_weights( + trainable=self.trainable, + sub_layers=self.layers, + extra_variables=self._extra_variables) @property def weights(self): return self.trainable_weights + self.non_trainable_weights @property + def trainable_variables(self): + return self.trainable_weights + + @property + def non_trainable_variables(self): + return self.non_trainable_weights + + @property def variables(self): return self.weights diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py index 31a0e8b622..b05b3a8800 100644 --- a/tensorflow/python/training/checkpointable/data_structures_test.py +++ b/tensorflow/python/training/checkpointable/data_structures_test.py @@ -139,6 +139,25 @@ class ListTests(test.TestCase): outer.variables[0], resource_variable_ops.ResourceVariable) + def testNonLayerVariables(self): + v = resource_variable_ops.ResourceVariable([1.]) + l = data_structures.List([v]) + self.assertTrue(l.trainable) + self.assertEqual([], l.layers) + self.assertEqual([v], l.variables) + self.assertEqual([v], l.trainable_weights) + self.assertEqual([], l.non_trainable_variables) + l.trainable = False + self.assertEqual([v], l.variables) + self.assertEqual([], l.trainable_variables) + self.assertEqual([v], l.non_trainable_variables) + l.trainable = True + v2 = resource_variable_ops.ResourceVariable(1., trainable=False) + l.append(v2) + self.assertEqual([v, v2], l.weights) + self.assertEqual([v], l.trainable_weights) + self.assertEqual([v2], l.non_trainable_weights) + def testHashing(self): has_sequences = set([data_structures.List(), data_structures.List()]) diff --git a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt index 8c8912dfab..23b552cc38 100644 --- a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt @@ -43,6 +43,10 @@ tf_class { name: "shape" mtype: "<type \'property\'>" } + member { + name: "trainable" + mtype: "<type \'property\'>" + } member_method { name: "__init__" argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " |