aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/resource_variable_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/resource_variable_ops.py')
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index e37e93ea35..7061b32808 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -551,6 +551,7 @@ class ResourceVariable(variables.Variable):
import_scope=import_scope))
else:
self._initial_value = None
+ self._trainable = getattr(variable_def, "trainable", True)
if variable_def.snapshot_name:
snapshot = g.as_graph_element(
ops.prepend_name_scope(
@@ -735,7 +736,7 @@ class ResourceVariable(variables.Variable):
return self._save_slice_info
def _read_variable_op(self):
- if hasattr(self, "_trainable") and self._trainable:
+ if self.trainable:
tape.watch_variable(self)
return gen_resource_variable_ops.read_variable_op(self._handle,
self._dtype)
@@ -760,7 +761,7 @@ class ResourceVariable(variables.Variable):
def sparse_read(self, indices, name=None):
"""Reads the value of this variable sparsely, using `gather`."""
with ops.name_scope("Gather" if name is None else name) as name:
- if self._trainable:
+ if self.trainable:
tape.watch_variable(self)
value = gen_resource_variable_ops.resource_gather(
self._handle, indices, dtype=self._dtype, name=name)
@@ -801,6 +802,7 @@ class ResourceVariable(variables.Variable):
var_def.snapshot_name = ops.strip_name_scope(self._graph_element.name,
export_scope)
var_def.is_resource = True
+ var_def.trainable = self.trainable
if self._save_slice_info:
var_def.save_slice_info_def.MergeFrom(
self._save_slice_info.to_proto(export_scope=export_scope))
@@ -913,7 +915,7 @@ class ResourceVariable(variables.Variable):
return assign_add_op
def _lazy_read(self, op):
- if hasattr(self, "_trainable") and self._trainable:
+ if self.trainable:
tape.watch_variable(self)
return _UnreadVariable(
self._handle, self.dtype, self._shape, self._in_graph_mode,