author     Allen Lavoie <allenl@google.com>    2018-05-30 19:01:58 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-05-30 19:04:42 -0700
commit     5be69b0c5e0087acedffe4e94a716c0b5ed320fb (patch)
tree       f5a81988b6232161d5cccf7db210e2ae3e262683
parent     d0f9424e22eb438f3d846fa62feaf331797e62c4 (diff)
Add a subclassed Model's attribute-assigned variables to Model.weights et al
Makes the Variable.trainable property public, which is sensible if we're discouraging use of the global collection (currently eager execution is using ResourceVariable._trainable in a bunch of places anyway). I'm leaving it read-only for now, since setting it would also need to toggle the variable in and out of the global collection.

Same change for checkpointable data structures with respect to gathering extra variables; they'll behave like subclassed Models. I think this makes more sense than trying to have a distinction between "variables" and "weights". It's also more sensible than collecting everything that would get checkpointed, since that would include Optimizer slot variables and metrics. Collecting those is generally pointless, and accidentally adding them to gradient tapes would be horribly confusing.

PiperOrigin-RevId: 198656079
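A minimal usage sketch of the new behavior (the ScaledDense class and its attribute names are hypothetical, and this assumes a TensorFlow build that includes this change): variables assigned directly to attributes of a subclassed Model now show up in Model.weights / Model.variables alongside layer-owned weights, and Variable.trainable is readable as a public property.

import tensorflow as tf


class ScaledDense(tf.keras.Model):
  """Subclassed Model that assigns extra variables directly to attributes."""

  def __init__(self):
    super(ScaledDense, self).__init__()
    self.dense = tf.keras.layers.Dense(1)
    # Attribute-assigned variables are gathered into weights/variables.
    self.scale = tf.Variable(2., trainable=True)
    self.offset = tf.Variable(0., trainable=False)

  def call(self, inputs):
    return self.dense(inputs * self.scale) + self.offset


model = ScaledDense()
model(tf.ones([1, 1]))  # Builds the Dense layer's kernel and bias.

# model.variables             -> [kernel, bias, scale, offset]
# model.trainable_weights     -> [kernel, bias, scale]
# model.non_trainable_weights -> [offset]
assert model.scale.trainable       # `trainable` is now a public property.
assert not model.offset.trainable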
-rw-r--r--  tensorflow/core/framework/variable.proto | 3
-rw-r--r--  tensorflow/python/eager/function.py | 2
-rw-r--r--  tensorflow/python/eager/graph_callable.py | 2
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_src.cc | 4
-rw-r--r--  tensorflow/python/keras/engine/network.py | 52
-rw-r--r--  tensorflow/python/keras/model_subclassing_test.py | 45
-rw-r--r--  tensorflow/python/keras/utils/layer_utils.py | 55
-rw-r--r--  tensorflow/python/kernel_tests/resource_variable_ops_test.py | 19
-rw-r--r--  tensorflow/python/kernel_tests/variables_test.py | 17
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py | 8
-rw-r--r--  tensorflow/python/ops/variable_scope.py | 6
-rw-r--r--  tensorflow/python/ops/variables.py | 7
-rw-r--r--  tensorflow/python/training/checkpointable/data_structures.py | 36
-rw-r--r--  tensorflow/python/training/checkpointable/data_structures_test.py | 19
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.-variable.pbtxt | 4
15 files changed, 233 insertions, 46 deletions
diff --git a/tensorflow/core/framework/variable.proto b/tensorflow/core/framework/variable.proto
index 93ae423bab..66ba4cba7d 100644
--- a/tensorflow/core/framework/variable.proto
+++ b/tensorflow/core/framework/variable.proto
@@ -26,6 +26,9 @@ message VariableDef {
// Whether to represent this as a ResourceVariable.
bool is_resource = 5;
+
+ // Whether this variable should be trained.
+ bool trainable = 7;
}
message SaveSliceInfoDef {
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 23d87fb394..559063d6ae 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -494,7 +494,7 @@ class GraphModeFunction(object):
def __call__(self, *args):
"""Executes the passed function in eager mode."""
for v in self._variables:
- if v._trainable: # pylint: disable=protected-access
+ if v.trainable:
tape.watch_variable(v)
tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index d9ffcbd203..760a148552 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -202,7 +202,7 @@ class _InitializingFunctionObject(object):
v.handle).numpy() for v in self._call_fn.variables]
if all(x for x in initialized):
for v in self._call_fn.variables:
- if v._trainable: # pylint: disable=protected-access
+ if v.trainable:
tape.watch_variable(v)
return self._call_fn(*args)
elif all(not x for x in initialized):
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 52b90504f3..e3ce0ef9d0 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1874,10 +1874,10 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
void MaybeWatchVariable(PyObject* input) {
DCHECK(CheckResourceVariable(input));
- DCHECK(PyObject_HasAttrString(input, "_trainable"));
+ DCHECK(PyObject_HasAttrString(input, "trainable"));
tensorflow::Safe_PyObjectPtr trainable(
- PyObject_GetAttrString(input, "_trainable"));
+ PyObject_GetAttrString(input, "trainable"));
if (trainable.get() == Py_False) return;
TFE_Py_TapeSetWatchVariable(input);
}
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 6db41472b6..f63ca1a207 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -36,9 +36,10 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import saving
from tensorflow.python.keras.utils import generic_utils
+from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
-from tensorflow.python.keras.utils.layer_utils import print_summary as print_layer_summary
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import data_structures_base
@@ -94,6 +95,11 @@ class Network(base_layer.Layer):
self.trainable = True
self._is_compiled = False
self._expects_training_arg = False
+ # A list of "extra" variables assigned to attributes of this class, included
+ # in self.weights and self.variables. Always empty for graph networks (but
+ # included in base_init to avoid excessive special casing when retrieving
+ # the value).
+ self._extra_variables = []
self.supports_masking = False
if not hasattr(self, 'optimizer'):
@@ -347,11 +353,22 @@ class Network(base_layer.Layer):
# layers). Therefore Model tracks Checkpointable objects itself.
self._track_checkpointable(
checkpointable=value, name=name, overwrite=True)
+ if ( # For subclassed models only, users may add extra weights/variables
+ # simply by assigning them to attributes.
+ not self._is_graph_network
+ and isinstance(value, variables.Variable)):
+ self._extra_variables.append(value)
super(Network, self).__setattr__(name, value)
def add_variable(self, name, shape, dtype=None, initializer=None,
regularizer=None, trainable=True, constraint=None):
- raise NotImplementedError('`add_variable` is not supported on Networks.')
+ if self._is_graph_network:
+ raise NotImplementedError('`add_variable` is not supported on Networks.')
+ else:
+ raise NotImplementedError(
+ '`add_variable` is not supported on Networks. However, you may '
+ 'assign variables to attributes and they will show up in the weights '
+ 'and variables properties.')
def add_loss(self, *args, **kwargs):
if context.executing_eagerly():
@@ -589,24 +606,17 @@ class Network(base_layer.Layer):
@property
def trainable_weights(self):
- if not self.trainable:
- return []
- weights = []
- for layer in self.layers:
- weights += layer.trainable_weights
- return weights
+ return layer_utils.gather_trainable_weights(
+ trainable=self.trainable,
+ sub_layers=self.layers,
+ extra_variables=self._extra_variables)
@property
def non_trainable_weights(self):
- weights = []
- for layer in self.layers:
- weights += layer.non_trainable_weights
- if not self.trainable:
- trainable_weights = []
- for layer in self.layers:
- trainable_weights += layer.trainable_weights
- return trainable_weights + weights
- return weights
+ return layer_utils.gather_non_trainable_weights(
+ trainable=self.trainable,
+ sub_layers=self.layers,
+ extra_variables=self._extra_variables)
@property
def input_spec(self):
@@ -1437,10 +1447,10 @@ class Network(base_layer.Layer):
'have not yet been created, so no summary can be '
'displayed. Build the model first '
'(e.g. by calling it on some data).')
- print_layer_summary(self,
- line_length=line_length,
- positions=positions,
- print_fn=print_fn)
+ layer_utils.print_summary(self,
+ line_length=line_length,
+ positions=positions,
+ print_fn=print_fn)
def get_source_inputs(tensor, layer=None, node_index=None):
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index 558854ab97..86f7e20bec 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -622,6 +622,51 @@ class ModelSubclassingTest(test.TestCase):
self.assertIs(m.isdep, m._checkpoint_dependencies[0].ref)
self.assertEqual('notdep_var:0', m.notdep_var.name)
+ def test_extra_variable(self):
+
+ class ExtraVar(keras.Model):
+
+ def __init__(self):
+ super(ExtraVar, self).__init__()
+ self.dense = keras.layers.Dense(1)
+ self.var = resource_variable_ops.ResourceVariable(1.)
+ self.not_trainable_var = resource_variable_ops.ResourceVariable(
+ 2., trainable=False)
+
+ def call(self, inputs):
+ return self.dense(inputs + self.var)
+
+ m = ExtraVar()
+ self.assertTrue(m.trainable)
+ self.assertEqual([m.dense], m.layers)
+ self.assertEqual([m.var, m.not_trainable_var], m.variables)
+ self.assertEqual([m.var], m.trainable_variables)
+ self.assertEqual([m.not_trainable_var], m.non_trainable_variables)
+ m.trainable = False
+ self.assertEqual([m.var, m.not_trainable_var], m.variables)
+ self.assertEqual([], m.trainable_variables)
+ self.assertEqual([m.var, m.not_trainable_var], m.non_trainable_variables)
+ m.trainable = True
+
+ m(array_ops.ones([1, 1]))
+
+ self.assertEqual([m.dense.kernel, m.dense.bias], m.dense.variables)
+ self.assertEqual([m.dense.kernel, m.dense.bias], m.dense.weights)
+
+ self.assertEqual([m.dense.kernel, m.dense.bias, m.var, m.not_trainable_var],
+ m.variables)
+ self.assertEqual([m.dense.kernel, m.dense.bias, m.var],
+ m.trainable_variables)
+ self.assertEqual([m.not_trainable_var], m.non_trainable_variables)
+
+ m.dense.trainable = False
+ self.assertEqual(
+ [m.var, m.dense.kernel, m.dense.bias, m.not_trainable_var],
+ m.variables)
+ self.assertEqual([m.var], m.trainable_variables)
+ self.assertEqual([m.dense.kernel, m.dense.bias, m.not_trainable_var],
+ m.non_trainable_variables)
+
class CustomCallModel(keras.Model):
diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py
index bd61f8e9cc..88daff0461 100644
--- a/tensorflow/python/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/utils/layer_utils.py
@@ -201,6 +201,61 @@ def print_summary(model, line_length=None, positions=None, print_fn=None):
print_fn('_' * line_length)
+def gather_trainable_weights(trainable, sub_layers, extra_variables):
+ """Lists the trainable weights for an object with sub-layers.
+
+ Args:
+ trainable: Whether the object collecting the variables is trainable.
+ sub_layers: A flat list of Layer objects owned by this object, to collect
+ variables from.
+ extra_variables: Any extra variables to include. Their `.trainable` property
+ is used to categorize them.
+
+ Returns:
+ A list of collected trainable weights/variables.
+ """
+ if not trainable:
+ return []
+ weights = []
+ for layer in sub_layers:
+ weights += layer.trainable_weights
+ trainable_extra_variables = [
+ v for v in extra_variables if v.trainable]
+ return weights + trainable_extra_variables
+
+
+def gather_non_trainable_weights(trainable, sub_layers, extra_variables):
+ """Lists the non-trainable weights for an object with sub-layers.
+
+ Args:
+ trainable: Whether the object collecting the variables is trainable.
+ sub_layers: A flat list of Layer objects owned by this object, to collect
+ variables from.
+ extra_variables: Any extra variables to include. Their `.trainable` property
+ is used to categorize them.
+
+ Returns:
+ A list of collected non-trainable weights/variables.
+ """
+ trainable_extra_variables = []
+ non_trainable_extra_variables = []
+ for v in extra_variables:
+ if v.trainable:
+ trainable_extra_variables.append(v)
+ else:
+ non_trainable_extra_variables.append(v)
+ weights = []
+ for layer in sub_layers:
+ weights += layer.non_trainable_weights
+ if not trainable:
+ trainable_weights = []
+ for layer in sub_layers:
+ trainable_weights += layer.trainable_weights
+ return (trainable_weights + trainable_extra_variables
+ + weights + non_trainable_extra_variables)
+ return weights + non_trainable_extra_variables
+
+
@tf_export('keras.utils.convert_all_kernels_in_model')
def convert_all_kernels_in_model(model):
"""Converts all convolution kernels in a model from Theano to TensorFlow.
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 972fbdb3d6..00d517e64e 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -538,6 +538,25 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
sess.run(v.initialized_value())
+ def testTrainableInProto(self):
+ with ops.Graph().as_default():
+ non_trainable_variable = resource_variable_ops.ResourceVariable(
+ trainable=False,
+ initial_value=constant_op.constant(10.0))
+ self.assertEqual(
+ False,
+ resource_variable_ops.ResourceVariable(
+ variable_def=non_trainable_variable.to_proto())
+ .trainable)
+ trainable_variable = resource_variable_ops.ResourceVariable(
+ trainable=True,
+ initial_value=constant_op.constant(10.0))
+ self.assertEqual(
+ True,
+ resource_variable_ops.ResourceVariable(
+ variable_def=trainable_variable.to_proto())
+ .trainable)
+
@test_util.run_in_graph_and_eager_modes()
def testSparseRead(self):
with self.test_session():
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 27599868b7..62d596da91 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -496,6 +496,23 @@ class VariablesTestCase(test.TestCase):
with self.assertRaises(ValueError):
sess.run(v.initialized_value())
+ def testTrainableInProto(self):
+ with ops.Graph().as_default():
+ non_trainable_variable = variables.Variable(
+ trainable=False,
+ initial_value=constant_op.constant(10.0))
+ self.assertEqual(
+ False,
+ variables.Variable(variable_def=non_trainable_variable.to_proto())
+ .trainable)
+ trainable_variable = variables.Variable(
+ trainable=True,
+ initial_value=constant_op.constant(10.0))
+ self.assertEqual(
+ True,
+ variables.Variable(variable_def=trainable_variable.to_proto())
+ .trainable)
+
def testLoad(self):
with self.test_session():
var = variables.Variable(np.zeros((5, 5), np.float32))
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index e37e93ea35..7061b32808 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -551,6 +551,7 @@ class ResourceVariable(variables.Variable):
import_scope=import_scope))
else:
self._initial_value = None
+ self._trainable = getattr(variable_def, "trainable", True)
if variable_def.snapshot_name:
snapshot = g.as_graph_element(
ops.prepend_name_scope(
@@ -735,7 +736,7 @@ class ResourceVariable(variables.Variable):
return self._save_slice_info
def _read_variable_op(self):
- if hasattr(self, "_trainable") and self._trainable:
+ if self.trainable:
tape.watch_variable(self)
return gen_resource_variable_ops.read_variable_op(self._handle,
self._dtype)
@@ -760,7 +761,7 @@ class ResourceVariable(variables.Variable):
def sparse_read(self, indices, name=None):
"""Reads the value of this variable sparsely, using `gather`."""
with ops.name_scope("Gather" if name is None else name) as name:
- if self._trainable:
+ if self.trainable:
tape.watch_variable(self)
value = gen_resource_variable_ops.resource_gather(
self._handle, indices, dtype=self._dtype, name=name)
@@ -801,6 +802,7 @@ class ResourceVariable(variables.Variable):
var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name,
export_scope)
var_def.is_resource = True
+ var_def.trainable = self.trainable
if self._save_slice_info:
var_def.save_slice_info_def.MergeFrom(
self._save_slice_info.to_proto(export_scope=export_scope))
@@ -913,7 +915,7 @@ class ResourceVariable(variables.Variable):
return assign_add_op
def _lazy_read(self, op):
- if hasattr(self, "_trainable") and self._trainable:
+ if self.trainable:
tape.watch_variable(self)
return _UnreadVariable(
self._handle, self.dtype, self._shape, self._in_graph_mode,
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 8d93d24b14..fa34774622 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -1261,13 +1261,13 @@ class EagerVariableStore(object):
def trainable_variables(self):
# pylint: disable=protected-access
- return sorted([x for x in self._store._vars.values() if x._trainable],
+ return sorted([x for x in self._store._vars.values() if x.trainable],
key=lambda x: x.name)
# pylint: enable=protected-access
def non_trainable_variables(self):
# pylint: disable=protected-access
- return sorted([x for x in self._store._vars.values() if not x._trainable],
+ return sorted([x for x in self._store._vars.values() if not x.trainable],
key=lambda x: x.name)
# pylint: enable=protected-access
@@ -1296,7 +1296,7 @@ class EagerVariableStore(object):
new_var = resource_variable_ops.ResourceVariable(
var.read_value(),
name=stripped_var_name,
- trainable=var._trainable)
+ trainable=var.trainable)
new_store._store._vars[key] = new_var
return new_store
# pylint: enable=protected-access
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index d88fd836f5..4be9f5eb68 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -341,6 +341,7 @@ class Variable(checkpointable.CheckpointableBase):
self._update_uid = initial_value.checkpoint_position.restore_uid
initial_value = initial_value.wrapped_value
+ self._trainable = trainable
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
with ops.init_scope():
@@ -450,6 +451,7 @@ class Variable(checkpointable.CheckpointableBase):
import_scope=import_scope))
else:
self._initial_value = None
+ self._trainable = getattr(variable_def, "trainable", True)
self._snapshot = g.as_graph_element(
ops.prepend_name_scope(variable_def.snapshot_name,
import_scope=import_scope))
@@ -543,6 +545,10 @@ class Variable(checkpointable.CheckpointableBase):
self._ref().set_shape(shape)
self.value().set_shape(shape)
+ @property
+ def trainable(self):
+ return self._trainable
+
def eval(self, session=None):
"""In a session, computes and returns the value of this variable.
@@ -1050,6 +1056,7 @@ class Variable(checkpointable.CheckpointableBase):
# For backwards compatibility.
var_def.initial_value_name = ops.strip_name_scope(
self._initial_value.name, export_scope)
+ var_def.trainable = self.trainable
var_def.initializer_name = ops.strip_name_scope(
self.initializer.name, export_scope)
var_def.snapshot_name = ops.strip_name_scope(
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index 62cefa4f20..69ed253fb2 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -22,6 +22,8 @@ import collections
import six
from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.utils import layer_utils
+from tensorflow.python.ops import variables
from tensorflow.python.training.checkpointable import base as checkpointable_lib
from tensorflow.python.training.checkpointable import data_structures_base
@@ -41,11 +43,14 @@ class CheckpointableDataStructure(
def __init__(self):
self._layers = []
self.trainable = True
+ self._extra_variables = []
def _track_value(self, value, name):
"""Add a dependency on `value`."""
if isinstance(value, checkpointable_lib.CheckpointableBase):
self._track_checkpointable(value, name=name)
+ if isinstance(value, variables.Variable):
+ self._extra_variables.append(value)
else:
raise ValueError(
("Only checkpointable objects (such as Layers or Optimizers) may be "
@@ -67,30 +72,31 @@ class CheckpointableDataStructure(
@property
def trainable_weights(self):
- if not self.trainable:
- return []
- weights = []
- for layer in self.layers:
- weights += layer.trainable_weights
- return weights
+ return layer_utils.gather_trainable_weights(
+ trainable=self.trainable,
+ sub_layers=self.layers,
+ extra_variables=self._extra_variables)
@property
def non_trainable_weights(self):
- weights = []
- for layer in self.layers:
- weights += layer.non_trainable_weights
- if not self.trainable:
- trainable_weights = []
- for layer in self.layers:
- trainable_weights += layer.trainable_weights
- return trainable_weights + weights
- return weights
+ return layer_utils.gather_non_trainable_weights(
+ trainable=self.trainable,
+ sub_layers=self.layers,
+ extra_variables=self._extra_variables)
@property
def weights(self):
return self.trainable_weights + self.non_trainable_weights
@property
+ def trainable_variables(self):
+ return self.trainable_weights
+
+ @property
+ def non_trainable_variables(self):
+ return self.non_trainable_weights
+
+ @property
def variables(self):
return self.weights
diff --git a/tensorflow/python/training/checkpointable/data_structures_test.py b/tensorflow/python/training/checkpointable/data_structures_test.py
index 31a0e8b622..b05b3a8800 100644
--- a/tensorflow/python/training/checkpointable/data_structures_test.py
+++ b/tensorflow/python/training/checkpointable/data_structures_test.py
@@ -139,6 +139,25 @@ class ListTests(test.TestCase):
outer.variables[0],
resource_variable_ops.ResourceVariable)
+ def testNonLayerVariables(self):
+ v = resource_variable_ops.ResourceVariable([1.])
+ l = data_structures.List([v])
+ self.assertTrue(l.trainable)
+ self.assertEqual([], l.layers)
+ self.assertEqual([v], l.variables)
+ self.assertEqual([v], l.trainable_weights)
+ self.assertEqual([], l.non_trainable_variables)
+ l.trainable = False
+ self.assertEqual([v], l.variables)
+ self.assertEqual([], l.trainable_variables)
+ self.assertEqual([v], l.non_trainable_variables)
+ l.trainable = True
+ v2 = resource_variable_ops.ResourceVariable(1., trainable=False)
+ l.append(v2)
+ self.assertEqual([v, v2], l.weights)
+ self.assertEqual([v], l.trainable_weights)
+ self.assertEqual([v2], l.non_trainable_weights)
+
def testHashing(self):
has_sequences = set([data_structures.List(),
data_structures.List()])
diff --git a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
index 8c8912dfab..23b552cc38 100644
--- a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt
@@ -43,6 +43,10 @@ tf_class {
name: "shape"
mtype: "<type \'property\'>"
}
+ member {
+ name: "trainable"
+ mtype: "<type \'property\'>"
+ }
member_method {
name: "__init__"
argspec: "args=[\'self\', \'initial_value\', \'trainable\', \'collections\', \'validate_shape\', \'caching_device\', \'name\', \'variable_def\', \'dtype\', \'expected_shape\', \'import_scope\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'True\', \'None\', \'True\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "