diff options
Diffstat (limited to 'tensorflow/python/training/warm_starting_util.py')
-rw-r--r-- | tensorflow/python/training/warm_starting_util.py | 18 |
1 files changed, 7 insertions, 11 deletions
diff --git a/tensorflow/python/training/warm_starting_util.py b/tensorflow/python/training/warm_starting_util.py index ec740abdd1..b1a7cfab83 100644 --- a/tensorflow/python/training/warm_starting_util.py +++ b/tensorflow/python/training/warm_starting_util.py @@ -22,7 +22,6 @@ import collections import six from tensorflow.python.framework import ops -from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib @@ -83,11 +82,6 @@ class VocabInfo( ) -def _is_variable(x): - return (isinstance(x, variables_lib.Variable) or - isinstance(x, resource_variable_ops.ResourceVariable)) - - def _infer_var_name(var): """Returns name of the `var`. @@ -126,9 +120,10 @@ def _warm_start_var(var, prev_ckpt, prev_tensor_name=None): prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If None, we lookup tensor with same name as given `var`. """ - if _is_variable(var): + if checkpoint_utils._is_variable(var): # pylint: disable=protected-access current_var_name = _infer_var_name([var]) - elif isinstance(var, list) and all(_is_variable(v) for v in var): + elif (isinstance(var, list) and + all(checkpoint_utils._is_variable(v) for v in var)): # pylint: disable=protected-access current_var_name = _infer_var_name(var) elif isinstance(var, variables_lib.PartitionedVariable): current_var_name = _infer_var_name([var]) @@ -193,9 +188,10 @@ def _warm_start_var_with_vocab(var, prev_vocab_path): raise ValueError("Invalid args: Must provide all of [current_vocab_path, " "current_vocab_size, prev_ckpt, prev_vocab_path}.") - if _is_variable(var): + if checkpoint_utils._is_variable(var): var = [var] - elif isinstance(var, list) and all(_is_variable(v) for v in var): + elif (isinstance(var, list) and + all(checkpoint_utils._is_variable(v) for v in var)): var = var elif isinstance(var, variables_lib.PartitionedVariable): var = var._get_variable_list() @@ -271,7 +267,7 @@ def _get_grouped_variables(vars_to_warm_start): for v in vars_to_warm_start: list_of_vars += ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope=v) - elif all([_is_variable(v) for v in vars_to_warm_start]): + elif all([checkpoint_utils._is_variable(v) for v in vars_to_warm_start]): # pylint: disable=protected-access list_of_vars = vars_to_warm_start else: raise ValueError("If `vars_to_warm_start` is a list, it must be all " |