aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/template.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/template.py')
-rw-r--r--tensorflow/python/ops/template.py13
1 files changed, 7 insertions, 6 deletions
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py
index 161d9687d6..e7ad261615 100644
--- a/tensorflow/python/ops/template.py
+++ b/tensorflow/python/ops/template.py
@@ -128,7 +128,7 @@ def make_template(name_, func_, create_scope_now_=False, unique_name_=None,
template of the same scope/unique_name already exists and reuse is false,
an error is raised. Defaults to None.
custom_getter_: Optional custom getter for variables used in `func_`. See
- the @{tf.get_variable} `custom_getter` documentation for
+ the `tf.get_variable` `custom_getter` documentation for
more information.
**kwargs: Keyword arguments to apply to `func_`.
@@ -176,7 +176,7 @@ def make_template_internal(name_,
template of the same scope/unique_name already exists and reuse is false,
an error is raised. Defaults to None. If executing eagerly, must be None.
custom_getter_: Optional custom getter for variables used in `func_`. See
- the @{tf.get_variable} `custom_getter` documentation for
+ the `tf.get_variable` `custom_getter` documentation for
more information.
create_graph_function_: When True, `func_` will be executed as a graph
function. This implies that `func_` must satisfy the properties that
@@ -298,9 +298,10 @@ class Template(checkpointable.CheckpointableBase):
def _call_func(self, args, kwargs):
try:
- vars_at_start = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+ vars_at_start = len(
+ ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES))
trainable_at_start = len(
- ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
+ ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES))
if self._variables_created:
result = self._func(*args, **kwargs)
else:
@@ -313,7 +314,7 @@ class Template(checkpointable.CheckpointableBase):
# Variables were previously created, implying this is not the first
# time the template has been called. Check to make sure that no new
# trainable variables were created this time around.
- trainable_variables = ops.get_collection(
+ trainable_variables = ops.get_collection_ref(
ops.GraphKeys.TRAINABLE_VARIABLES)
# If a variable that we intend to train is created as a side effect
# of creating a template, then that is almost certainly an error.
@@ -326,7 +327,7 @@ class Template(checkpointable.CheckpointableBase):
# Non-trainable tracking variables are a legitimate reason why a new
# variable would be created, but it is a relatively advanced use-case,
# so log it.
- variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ variables = ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)
if vars_at_start != len(variables):
logging.info("New variables created when calling a template after "
"the first time, perhaps you used tf.Variable when you "