diff options
Diffstat (limited to 'tensorflow/python/ops/template.py')
-rw-r--r-- | tensorflow/python/ops/template.py | 13 |
1 file changed, 7 insertions, 6 deletions
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index 161d9687d6..e7ad261615 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -128,7 +128,7 @@ def make_template(name_, func_, create_scope_now_=False, unique_name_=None, template of the same scope/unique_name already exists and reuse is false, an error is raised. Defaults to None. custom_getter_: Optional custom getter for variables used in `func_`. See - the @{tf.get_variable} `custom_getter` documentation for + the `tf.get_variable` `custom_getter` documentation for more information. **kwargs: Keyword arguments to apply to `func_`. @@ -176,7 +176,7 @@ def make_template_internal(name_, template of the same scope/unique_name already exists and reuse is false, an error is raised. Defaults to None. If executing eagerly, must be None. custom_getter_: Optional custom getter for variables used in `func_`. See - the @{tf.get_variable} `custom_getter` documentation for + the `tf.get_variable` `custom_getter` documentation for more information. create_graph_function_: When True, `func_` will be executed as a graph function. This implies that `func_` must satisfy the properties that @@ -298,9 +298,10 @@ class Template(checkpointable.CheckpointableBase): def _call_func(self, args, kwargs): try: - vars_at_start = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) + vars_at_start = len( + ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)) trainable_at_start = len( - ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)) + ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)) if self._variables_created: result = self._func(*args, **kwargs) else: @@ -313,7 +314,7 @@ class Template(checkpointable.CheckpointableBase): # Variables were previously created, implying this is not the first # time the template has been called. Check to make sure that no new # trainable variables were created this time around. 
- trainable_variables = ops.get_collection( + trainable_variables = ops.get_collection_ref( ops.GraphKeys.TRAINABLE_VARIABLES) # If a variable that we intend to train is created as a side effect # of creating a template, then that is almost certainly an error. @@ -326,7 +327,7 @@ class Template(checkpointable.CheckpointableBase): # Non-trainable tracking variables are a legitimate reason why a new # variable would be created, but it is a relatively advanced use-case, # so log it. - variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + variables = ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES) if vars_at_start != len(variables): logging.info("New variables created when calling a template after " "the first time, perhaps you used tf.Variable when you " |