diff options
author | 2018-05-30 19:01:58 -0700 | |
---|---|---|
committer | 2018-05-30 19:04:42 -0700 | |
commit | 5be69b0c5e0087acedffe4e94a716c0b5ed320fb (patch) | |
tree | f5a81988b6232161d5cccf7db210e2ae3e262683 /tensorflow/python/ops/variable_scope.py | |
parent | d0f9424e22eb438f3d846fa62feaf331797e62c4 (diff) |
Add a subclassed Model's attribute-assigned variables to Model.weights et al
Makes the Variable.trainable property public, which is sensible if we're discouraging use of the global collection (currently eager execution is using ResourceVariable._trainable in a bunch of places anyway). I'm leaving it read-only for now, since we should toggle in and out of the global collection when it changes.
Same change for checkpointable data structures with respect to gathering extra variables. They'll behave like subclassed Models.
I think this makes more sense than trying to have a distinction between "variables" and "weights". It's also more sensible than collecting everything that would get checkpointed, since that will include Optimizer slot variables and metrics. Collecting those is generally pointless, and accidentally adding them to gradient tapes would be horribly confusing.
PiperOrigin-RevId: 198656079
Diffstat (limited to 'tensorflow/python/ops/variable_scope.py')
-rw-r--r-- | tensorflow/python/ops/variable_scope.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 8d93d24b14..fa34774622 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -1261,13 +1261,13 @@ class EagerVariableStore(object): def trainable_variables(self): # pylint: disable=protected-access - return sorted([x for x in self._store._vars.values() if x._trainable], + return sorted([x for x in self._store._vars.values() if x.trainable], key=lambda x: x.name) # pylint: enable=protected-access def non_trainable_variables(self): # pylint: disable=protected-access - return sorted([x for x in self._store._vars.values() if not x._trainable], + return sorted([x for x in self._store._vars.values() if not x.trainable], key=lambda x: x.name) # pylint: enable=protected-access @@ -1296,7 +1296,7 @@ class EagerVariableStore(object): new_var = resource_variable_ops.ResourceVariable( var.read_value(), name=stripped_var_name, - trainable=var._trainable) + trainable=var.trainable) new_store._store._vars[key] = new_var return new_store # pylint: enable=protected-access |