aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/warm_starting_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/warm_starting_util.py')
-rw-r--r--tensorflow/python/training/warm_starting_util.py18
1 files changed, 7 insertions, 11 deletions
diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py
index ec740abdd1..b1a7cfab83 100644
--- a/tensorflow/python/training/warm_starting_util.py
+++ b/tensorflow/python/training/warm_starting_util.py
@@ -22,7 +22,6 @@ import collections
import six
from tensorflow.python.framework import ops
-from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
@@ -83,11 +82,6 @@ class VocabInfo(
)
-def _is_variable(x):
- return (isinstance(x, variables_lib.Variable) or
- isinstance(x, resource_variable_ops.ResourceVariable))
-
-
def _infer_var_name(var):
"""Returns name of the `var`.
@@ -126,9 +120,10 @@ def _warm_start_var(var, prev_ckpt, prev_tensor_name=None):
prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
None, we lookup tensor with same name as given `var`.
"""
- if _is_variable(var):
+ if checkpoint_utils._is_variable(var): # pylint: disable=protected-access
current_var_name = _infer_var_name([var])
- elif isinstance(var, list) and all(_is_variable(v) for v in var):
+ elif (isinstance(var, list) and
+ all(checkpoint_utils._is_variable(v) for v in var)): # pylint: disable=protected-access
current_var_name = _infer_var_name(var)
elif isinstance(var, variables_lib.PartitionedVariable):
current_var_name = _infer_var_name([var])
@@ -193,9 +188,10 @@ def _warm_start_var_with_vocab(var,
prev_vocab_path):
raise ValueError("Invalid args: Must provide all of [current_vocab_path, "
"current_vocab_size, prev_ckpt, prev_vocab_path}.")
- if _is_variable(var):
+ if checkpoint_utils._is_variable(var):
var = [var]
- elif isinstance(var, list) and all(_is_variable(v) for v in var):
+ elif (isinstance(var, list) and
+ all(checkpoint_utils._is_variable(v) for v in var)):
var = var
elif isinstance(var, variables_lib.PartitionedVariable):
var = var._get_variable_list()
@@ -271,7 +267,7 @@ def _get_grouped_variables(vars_to_warm_start):
for v in vars_to_warm_start:
list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
scope=v)
- elif all([_is_variable(v) for v in vars_to_warm_start]):
+ elif all([checkpoint_utils._is_variable(v) for v in vars_to_warm_start]): # pylint: disable=protected-access
list_of_vars = vars_to_warm_start
else:
raise ValueError("If `vars_to_warm_start` is a list, it must be all "