diff options
Diffstat (limited to 'tensorflow/python/ops/resource_variable_ops.py')
-rw-r--r-- | tensorflow/python/ops/resource_variable_ops.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index e37e93ea35..7061b32808 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -551,6 +551,7 @@ class ResourceVariable(variables.Variable): import_scope=import_scope)) else: self._initial_value = None + self._trainable = getattr(variable_def, "trainable", True) if variable_def.snapshot_name: snapshot = g.as_graph_element( ops.prepend_name_scope( @@ -735,7 +736,7 @@ class ResourceVariable(variables.Variable): return self._save_slice_info def _read_variable_op(self): - if hasattr(self, "_trainable") and self._trainable: + if self.trainable: tape.watch_variable(self) return gen_resource_variable_ops.read_variable_op(self._handle, self._dtype) @@ -760,7 +761,7 @@ class ResourceVariable(variables.Variable): def sparse_read(self, indices, name=None): """Reads the value of this variable sparsely, using `gather`.""" with ops.name_scope("Gather" if name is None else name) as name: - if self._trainable: + if self.trainable: tape.watch_variable(self) value = gen_resource_variable_ops.resource_gather( self._handle, indices, dtype=self._dtype, name=name) @@ -801,6 +802,7 @@ class ResourceVariable(variables.Variable): var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name, export_scope) var_def.is_resource = True + var_def.trainable = self.trainable if self._save_slice_info: var_def.save_slice_info_def.MergeFrom( self._save_slice_info.to_proto(export_scope=export_scope)) @@ -913,7 +915,7 @@ class ResourceVariable(variables.Variable): return assign_add_op def _lazy_read(self, op): - if hasattr(self, "_trainable") and self._trainable: + if self.trainable: tape.watch_variable(self) return _UnreadVariable( self._handle, self.dtype, self._shape, self._in_graph_mode, |