aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/variable_scope.py
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-05-30 19:01:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-30 19:04:42 -0700
commit5be69b0c5e0087acedffe4e94a716c0b5ed320fb (patch)
treef5a81988b6232161d5cccf7db210e2ae3e262683 /tensorflow/python/ops/variable_scope.py
parentd0f9424e22eb438f3d846fa62feaf331797e62c4 (diff)
Add a subclassed Model's attribute-assigned variables to Model.weights et al
Makes the Variable.trainable property public, which is sensible if we're discouraging use of the global collection (currently eager execution is using ResourceVariable._trainable in a bunch of places anyway). I'm leaving it read-only for now, since we should toggle in and out of the global collection when it changes. Same change for checkpointable data structures with respect to gathering extra variables. They'll behave like subclassed Models. I think this makes more sense than trying to have a distinction between "variables" and "weights". It's also more sensible than collecting everything that would get checkpointed, since that will include Optimizer slot variables and metrics. Collecting those is generally pointless, and accidentally adding them to gradient tapes would be horribly confusing. PiperOrigin-RevId: 198656079
Diffstat (limited to 'tensorflow/python/ops/variable_scope.py')
-rw-r--r--tensorflow/python/ops/variable_scope.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 8d93d24b14..fa34774622 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -1261,13 +1261,13 @@ class EagerVariableStore(object):
def trainable_variables(self):
# pylint: disable=protected-access
- return sorted([x for x in self._store._vars.values() if x._trainable],
+ return sorted([x for x in self._store._vars.values() if x.trainable],
key=lambda x: x.name)
# pylint: enable=protected-access
def non_trainable_variables(self):
# pylint: disable=protected-access
- return sorted([x for x in self._store._vars.values() if not x._trainable],
+ return sorted([x for x in self._store._vars.values() if not x.trainable],
key=lambda x: x.name)
# pylint: enable=protected-access
@@ -1296,7 +1296,7 @@ class EagerVariableStore(object):
new_var = resource_variable_ops.ResourceVariable(
var.read_value(),
name=stripped_var_name,
- trainable=var._trainable)
+ trainable=var.trainable)
new_store._store._vars[key] = new_var
return new_store
# pylint: enable=protected-access