about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow')
-rw-r--r-- tensorflow/core/framework/variable.proto 3
-rw-r--r-- tensorflow/python/eager/function.py 2
-rw-r--r-- tensorflow/python/eager/graph_callable.py 2
-rw-r--r-- tensorflow/python/eager/pywrap_tfe_src.cc 4
-rw-r--r-- tensorflow/python/keras/engine/network.py 52
-rw-r--r-- tensorflow/python/keras/model_subclassing_test.py 45
-rw-r--r-- tensorflow/python/keras/utils/layer_utils.py 55
-rw-r--r-- tensorflow/python/kernel_tests/resource_variable_ops_test.py 19
-rw-r--r-- tensorflow/python/kernel_tests/variables_test.py 17
-rw-r--r-- tensorflow/python/ops/resource_variable_ops.py 8
-rw-r--r-- tensorflow/python/ops/variable_scope.py 6
-rw-r--r-- tensorflow/python/ops/variables.py 7
-rw-r--r-- tensorflow/python/training/checkpointable/data_structures.py 36
-rw-r--r-- tensorflow/python/training/checkpointable/data_structures_test.py 19
-rw-r--r-- tensorflow/tools/api/golden/tensorflow.-variable.pbtxt 4
15 files changed, 233 insertions, 46 deletions
diff --git a/tensorflow/core/framework/variable.proto b/tensorflow/core/framework/variable.proto
index 93ae423bab..66ba4cba7d 100644
--- a/tensorflow/core/framework/variable.proto
+++ b/tensorflow/core/framework/variable.proto
@@ -26,6 +26,9 @@ message VariableDef {
// Whether to represent this as a ResourceVariable.
bool is_resource = 5;
+
+ // Whether this variable should be trained.
+ bool trainable = 7;
}
message SaveSliceInfoDef {
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 23d87fb394..559063d6ae 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -494,7 +494,7 @@ class GraphModeFunction(object):
def __call__(self, *args):
"""Executes the passed function in eager mode."""
for v in self._variables:
- if v._trainable: # pylint: disable=protected-access
+ if v.trainable:
tape.watch_variable(v)
tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index d9ffcbd203..760a148552 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -202,7 +202,7 @@ class _InitializingFunctionObject(object):
v.handle).numpy() for v in self._call_fn.variables]
if all(x for x in initialized):
for v in self._call_fn.variables:
- if v._trainable: # pylint: disable=protected-access
+ if v.trainable:
tape.watch_variable(v)
return self._call_fn(*args)
elif all(not x for x in initialized):
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 52b90504f3..e3ce0ef9d0 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1874,10 +1874,10 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
void MaybeWatchVariable(PyObject* input) {
DCHECK(CheckResourceVariable(input));
- DCHECK(PyObject_HasAttrString(input, "_trainable"));
+ DCHECK(PyObject_HasAttrString(input, "trainable"));
tensorflow::Safe_PyObjectPtr trainable(
- PyObject_GetAttrString(input, "_trainable"));
+ PyObject_GetAttrString(input, "trainable"));
if (trainable.get() == Py_False) return;
TFE_Py_TapeSetWatchVariable(input);
}
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 6db41472b6..f63ca1a207 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -36,9 +36,10 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import saving
from tensorflow.python.keras.utils import generic_utils
+from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
-from tensorflow.python.keras.utils.layer_utils import print_summary as print_layer_summary
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import data_structures_base
@@ -94,6 +95,11 @@ class Network(base_layer.Layer):
self.trainable = True
self._is_compiled = False
self._expects_training_arg = False
+ # A list of "extra" variables assigned to attributes of this class, included
+ # in self.weights and self.variables. Always empty for graph networks (but
+ # included in base_init to avoid excessive special casing when retrieving
+ # the value).
+ self._extra_variables = []
self.supports_masking = False
if not hasattr(self, 'optimizer'):
@@ -347,11 +353,22 @@ class Network(base_layer.Layer):
# layers). Therefore Model tracks Checkpointable objects itself.
self._track_checkpointable(
checkpointable=value, name=name, overwrite=True)
+ if ( # For subclassed models only, users may add extra weights/variables
+ # simply by assigning them to attributes.
+ not self._is_graph_network
+ and isinstance(value, variables.Variable)):
+ self._extra_variables.append(value)
super(Network, self).__setattr__(name, value)
def add_variable(self, name, shape, dtype=None, initializer=None,
regularizer=None, trainable=True, constraint=None):
- raise NotImplementedError('`add_variable` is not supported on Networks.')
+ if self._is_graph_network:
+ raise NotImplementedError('`add_variable` is not supported on Networks.')
+ else:
+ raise NotImplementedError(
+ '`add_variable` is not supported on Networks. However, you may '
+ 'assign variables to attributes and they will show up in the weights '
+ 'and variables properties.')
def add_loss(self, *args, **kwargs):
if context.executing_eagerly():
@@ -589,24 +606,17 @@ class Network(base_layer.Layer):
@property
def trainable_weights(self):
- if not self.trainable:
- return []
- weights = []
- for layer in self.layers:
- weights += layer.trainable_weights
- return weights
+ return layer_utils.gather_trainable_weights(
+ trainable=self.trainable,
+ sub_layers=self.layers,
+ extra_variables=self._extra_variables)
@property
def non_trainable_weights(self):
- weights = []
- for layer in self.layers:
- weights += layer.non_trainable_weights
- if not self.trainable:
- trainable_weights = []
- for layer in self.layers:
- trainable_weights += layer.trainable_weights
- return trainable_weights + weights
- return weights
+ return layer_utils.gather_non_trainable_weights(
+ trainable=self.trainable,
+ sub_layers=self.layers,
+ extra_variables=self._extra_variables)
@property
def input_spec(self):
@@ -1437,10 +1447,10 @@ class Network(base_layer.Layer):
'have not yet been created, so no summary can be '
'displayed. Build the model first '
'(e.g. by calling it on some data).')
- print_layer_summary(self,
- line_length=line_length,
- positions=positions,
- print_fn=print_fn)
+ layer_utils.print_summary(self,
+ line_length=line_length,
+ positions=positions,
+ print_fn=print_fn)
def get_source_inputs(tensor, layer=None, node_index=None):
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index 558854ab97..86f7e20bec 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -622,6 +622,51 @@ class ModelSubclassingTest(test.TestCase):
self.assertIs(m.isdep, m._checkpoint_dependencies[0].ref)
self.assertEqual('notdep_var:0', m.notdep_var.name)
+ def test_extra_variable(self):
+
+ class ExtraVar(keras.Model):
+
+ def __init__(self):
+ super(ExtraVar, self).__init__()
+ self.dense = keras.layers.Dense(1)
+ self.var = resource_variable_ops.ResourceVariable(1.)
+ self.not_trainable_var = resource_variable_ops.ResourceVariable(
+ 2., trainable=False)
+
+ def call(self, inputs):
+ return self.dense(inputs + self.var)
+
+ m = ExtraVar()
+ self.assertTrue(m.trainable)
+ self.assertEqual([m.dense], m.layers)
+ self.assertEqual([m.var, m.not_trainable_var], m.variables)
+ self.assertEqual([m.var], m.trainable_variables)
+ self.assertEqual([m.not_trainable_var], m.non_trainable_variables)
+ m.trainable = False
+ self.assertEqual([m.var, m.not_trainable_var], m.variables)
+ self.assertEqual([], m.trainable_variables)
+ self.assertEqual([m.var, m.not_trainable_var], m.non_trainable_variables)
+ m.trainable = True
+
+ m(array_ops.ones([1, 1]))
+
+ self.assertEqual([m.dense.kernel, m.dense.bias], m.dense.variables)
+ self.assertEqual([m.dense.kernel, m.dense.bias], m.dense.weights)
+
+ self.assertEqual([m.dense.kernel, m.dense.bias, m.var, m.not_trainable_var],
+ m.variables)
+ self.assertEqual([m.dense.kernel, m.dense.bias, m.var],
+ m.trainable_variables)
+ self.assertEqual([m.not_trainable_var], m.non_trainable_variables)
+
+ m.dense.trainable = False
+ self.assertEqual(
+ [m.var, m.dense.kernel, m.dense.bias, m.not_trainable_var],
+ m.variables)
+ self.assertEqual([m.var], m.trainable_variables)
+ self.assertEqual([m.dense.kernel, m.dense.bias, m.not_trainable_var],
+ m.non_trainable_variables)
+
class CustomCallModel(keras.Model):
diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py
index bd61f8e9cc..88daff0461 100644
--- a/tensorflow/python/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/utils/layer_utils.py
@@ -201,6 +201,61 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
print_fn('_' * line_length)
+def gather_trainable_weights(trainable, sub_layers, extra_variables):
+ """Lists the trainable weights for an object with sub-layers.
+
+ Args:
+ trainable: Whether the object collecting the variables is trainable.
+ sub_layers: A flat list of Layer objects owned by this object, to collect
+ variables from.
+ extra_variables: Any extra variables to include. Their `.trainable` property
+ is used to categorize them.
+
+ Returns:
+ A list of collected trainable weights/variables.
+ """
+ if not trainable:
+ return []
+ weights = []
+ for layer in sub_layers:
+ weights += layer.trainable_weights
+ trainable_extra_variables = [
+ v for v in extra_variables if v.trainable]
+ return weights + trainable_extra_variables
+
+
+def gather_non_trainable_weights(trainable, sub_layers, extra_variables):
+ """Lists the non-trainable weights for an object with sub-layers.
+
+ Args:
+ trainable: Whether the object collecting the variables is trainable.
+ sub_layers: A flat list of Layer objects owned by this object, to collect
+ variables from.
+ extra_variables: Any extra variables to include. Their `.trainable` property
+ is used to categorize them.
+
+ Returns:
+ A list of collected non-trainable weights/variables.
+ """
+ trainable_extra_variables = []
+ non_trainable_extra_variables = []
+ for v in extra_variables:
+ if v.trainable:
+ trainable_extra_variables.append(v)
+ else:
+ non_trainable_extra_variables.append(v)
+ weights = []
+ for layer in sub_layers:
+ weights += layer.non_trainable_weights
+ if not trainable:
+ trainable_weights = []
+ for layer in sub_layers:
+ trainable_weights += layer.trainable_weights
+ return (trainable_weights + trainable_extra_variables
+ + weights + non_trainable_extra_variables)
+ return weights + non_trainable_extra_variables
+
+
@tf_export('keras.utils.convert_all_kernels_in_model')
def convert_all_kernels_in_model(model):
"""Converts all convolution kernels in a model from Theano to TensorFlow.
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 972fbdb3d6..00d517e64e 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -538,6 +538,25 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
sess.run(v.initialized_value())
+ def testTrainableInProto(self):
+ with ops.Graph().as_default():
+ non_trainable_variable = resource_variable_ops.ResourceVariable(
+ trainable=False,
+ initial_value=constant_op.constant(10.0))
+ self.assertEqual(
+ False,
+ resource_variable_ops.ResourceVariable(
+ variable_def=non_trainable_variable.to_proto())
+ .trainable)
+ trainable_variable = resource_variable_ops.ResourceVariable(
+ trainable=True,
+ initial_value=constant_op.constant(10.0))
+ self.assertEqual(
+ True,
+ resource_variable_ops.ResourceVariable(
+ variable_def=trainable_variable.to_proto())
+ .trainable)
+
@test_util.run_in_graph_and_eager_modes()
def testSparseRead(self):
with self.test_session():
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 27599868b7..62d596da91 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -496,6 +496,23 @@ class VariablesTestCase(test.TestCase):
with self.assertRaises(ValueError):
sess.run(v.initialized_value())
+ def testTrainableInProto(self):
+ with ops.Graph().as_default():
+ non_trainable_variable = variables.Variable(
+ trainable=False,
+ initial_value=constant_op.constant(10.0))
+ self.assertEqual(
+ False,
+ variables.Variable(variable_def=non_trainable_variable.to_proto())
+ .trainable)
+ trainable_variable = variables.Variable(
+ trainable=True,
+ initial_value=constant_op.constant(10.0))
+ self.assertEqual(
+ True,
+ variables.Variable(variable_def=trainable_variable.to_proto())
+ .trainable)
+
def testLoad(self):
with self.test_session():
var = variables.Variable(np.zeros((5, 5), np.float32))
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index e37e93ea35..7061b32808 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -551,6 +551,7 @@ class ResourceVariable(variables.Variable):
import_scope=import_scope))
else:
self._initial_value = None
+ self._trainable = getattr(variable_def, "trainable", True)
if variable_def.snapshot_name:
snapshot = g.as_graph_element(
ops.prepend_name_scope(
@@ -735,7 +736,7 @@ class ResourceVariable(variables.Variable):
return self._save_slice_info
def _read_variable_op(self):
- if hasattr(self, "_trainable") and self._trainable:
+ if self.trainable:
tape.watch_variable(self)
return gen_resource_variable_ops.read_variable_op(self._handle,
self._dtype)
@@ -760,7 +761,7 @@ class ResourceVariable(variables.Variable):
def sparse_read(self, indices, name=None):
"""Reads the value of this variable sparsely, using `gather`."""
with ops.name_scope("Gather" if name is None else name) as name:
- if self._trainable:
+ if self.trainable:
tape.watch_variable(self)
value = gen_resource_variable_ops.resource_gather(
self._handle, indices, dtype=self._dtype, name=name)
@@ -801,6 +802,7 @@ class ResourceVariable(variables.Variable):
var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name,
export_scope)
var_def.is_resource = True
+ var_def.trainable = self.trainable
if self._save_slice_info:
var_def.save_slice_info_def.MergeFrom(
self._save_slice_info.to_proto(export_scope=export_scope))
@@ -913,7 +915,7 @@ class ResourceVariable(variables.Variable):
return assign_add_op
def _lazy_read(self, op):
- if hasattr(self, "_trainable") and self._trainable:
+ if self.trainable:
tape.watch_variable(self)
return _UnreadVariable(
self._handle, self.dtype, self._shape, self._in_graph_mode,
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 8d93d24b14..fa34774622 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -1261,13 +1261,13 @@ class EagerVariableStore(object):
def trainable_variables(self):
# pylint: disable=protected-access
- return sorted([x for x in self._store._vars.values() if x._trainable],
+ return sorted([x for x in self._store._vars.values() if x.trainable],
key=lambda x: x.name)
# pylint: enable=protected-access
def non_trainable_variables(self):
# pylint: disable=protected-access
- return sorted([x for x in self._store._vars.values() if not x._trainable],
+ return sorted([x for x in self._store._vars.values() if not x.trainable],
key=lambda x: x.name)
# pylint: enable=protected-access
@@ -1296,7 +1296,7 @@ class EagerVariableStore(object):
new_var = resource_variable_ops.ResourceVariable(
var.read_value(),
name=stripped_var_name,
- trainable=var._trainable)
+ trainable=var.trainable)
new_store._store._vars[key] = new_var
return new_store
# pylint: enable=protected-access
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index d88fd836f5..4be9f5eb68 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -341,6 +341,7 @@ class Variable(checkpointable.CheckpointableBase):
self._update_uid = initial_value.checkpoint_position.restore_uid
initial_value = initial_value.wrapped_value
+ self._trainable = trainable
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
with ops.init_scope():
@@ -450,6 +451,7 @@ class Variable(checkpointable.CheckpointableBase):
import_scope=import_scope))
else:
self._initial_value = None
+ self._trainable = getattr(variable_def, "trainable", True)
self._snapshot = g.as_graph_element(
ops.prepend_name_scope(variable_def.snapshot_name,
import_scope=import_scope))
@@ -543,6 +545,10 @@ class Variable(checkpointable.CheckpointableBase):
self._ref().set_shape(shape)
self.value().set_shape(shape)
+ @property
+ def trainable(self):
+ return self._trainable
+
def eval(self, session=None):
"""In a session, computes and returns the value of this variable.
@@ -1050,6 +1056,7 @@ class Variable(checkpointable.CheckpointableBase):
# For backwards compatibility.
var_def.initial_value_name = ops.strip_name_scope(
self._initial_value.name, export_scope)
+ var_def.trainable = self.trainable
var_def.initializer_name = ops.strip_name_scope(
self.initializer.name, export_scope)
var_def.snapshot_name = ops.strip_name_scope(
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index 62cefa4f20..69ed253fb2 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -22,6 +22,8 @@ import collections
import six
from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.utils import layer_utils
+from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import base as checkpointable_lib
from tensorflow.python.training.checkpointable import data_structures_base
@@ -41,11 +43,14 @@ class CheckpointableDataStructure(
def __init__(self):
self._layers = []
self.trainable = True
+ self._extra_variables = []
def _track_value(self, value, name):
"""Add a dependency on `value`."""
if isinstance(value, checkpointable_lib.CheckpointableBase):
self._track_checkpointable(value, name=name)
+ if isinstance(value, variables.Variable):
+ self._extra_variables.append(value)
else:
raise ValueError(
("Only checkpointable objects (such as Layers or Optimizers) may be "
@@ -67,30 +72,31 @@ class CheckpointableDataStructure(
@property
def trainable_weights(self):
- if not self.trainable:
- return []
- weights = []
- for layer in self.layers:
- weights += layer.trainable_weights
- return weights
+ return layer_utils.gather_trainable_weights(
+ trainable=self.trainable,
+ sub_layers=self.layers,
+ extra_variables=self._extra_variables)
@property
def non_trainable_weights(self):
- weights = []
- for layer in self.layers:
- weights += layer.non_trainable_weights
- if not self.trainable:
- trainable_weights = []
- for layer in self.layers:
- trainable_weights += layer.trainable_weights
- return trainable_weights + weights
- return weights
+ return layer_utils.gather_non_trainable_weights(
+ trainable=self.trainable,
+ sub_layers=self.layers,
+ extra_variables=self._extra_variables)
@property
def weights(self):
return self.trainable_weights + self.non_trainable_weights
@property
+ def trainable_variables(self):
+ return self.trainable_weights
+
+ @property
+ def non_trainable_variables(self):
+ return self.non_trainable_weights
+
+ @property
def variables(self):
return self.weights
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
index 31a0e8b622..b05b3a8800 100644
--- a/tensorflow/python/training/checkpointable/data_structures_test.py
+++ b/tensorflow/python/training/checkpointable/data_structures_test.py
@@ -139,6 +139,25 @@ class ListTests(test.TestCase):
outer.variables[0],
resource_variable_ops.ResourceVariable)
+ def testNonLayerVariables(self):
+ v = resource_variable_ops.ResourceVariable([1.])
+ l = data_structures.List([v])
+ self.assertTrue(l.trainable)
+ self.assertEqual([], l.layers)
+ self.assertEqual([v], l.variables)
+ self.assertEqual([v], l.trainable_weights)
+ self.assertEqual([], l.non_trainable_variables)
+ l.trainable = False
+ self.assertEqual([v], l.variables)
+ self.assertEqual([], l.trainable_variables)
+ self.assertEqual([v], l.non_trainable_variables)
+ l.trainable = True
+ v2 = resource_variable_ops.ResourceVariable(1., trainable=False)
+ l.append(v2)
+ self.assertEqual([v, v2], l.weights)
+ self.assertEqual([v], l.trainable_weights)
+ self.assertEqual([v2], l.non_trainable_weights)
+
def testHashing(self):
has_sequences = set([data_structures.List(),
data_structures.List()])
diff --git a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
index 8c8912dfab..23b552cc38 100644
--- a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
@@ -43,6 +43,10 @@ tf_class {
name: "shape"
mtype: "<type \'property\'>"
}
+ member {
+ name: "trainable"
+ mtype: "<type \'property\'>"
+ }
member_method {
name: "__init__"
argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "