diff options
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/estimator.py | 13 | ||||
-rw-r--r-- | tensorflow/python/util/deprecation.py | 33 |
2 files changed, 38 insertions, 8 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 91d900395b..2d40caa656 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -341,7 +341,8 @@ class BaseEstimator( return copy.deepcopy(self._config) @deprecated_args( - SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y', 'batch_size' + SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), + ('y', None), ('batch_size', None) ) def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None, monitors=None, max_steps=None): @@ -367,7 +368,8 @@ class BaseEstimator( return self @deprecated_args( - SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y', 'batch_size' + SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), + ('y', None), ('batch_size', None) ) def partial_fit( self, x=None, y=None, input_fn=None, steps=1, batch_size=None, @@ -411,7 +413,8 @@ class BaseEstimator( batch_size=batch_size, monitors=monitors) @deprecated_args( - SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y', 'batch_size' + SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), + ('y', None), ('batch_size', None) ) def evaluate( self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None, @@ -442,8 +445,8 @@ class BaseEstimator( return eval_results @deprecated_args( - SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'batch_size', - 'as_iterable' + SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), + ('batch_size', None), ('as_iterable', True) ) def predict( self, x=None, input_fn=None, batch_size=None, outputs=None, diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py index 32d2678ce7..c4afed649e 100644 --- a/tensorflow/python/util/deprecation.py +++ b/tensorflow/python/util/deprecation.py @@ -224,6 +224,33 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples): 'in the function signature: %s. ' 'Found next arguments: %s.' % (missing_args, known_args)) + def _same_value(a, b): + """A comparison operation that works for multiple object types. + + Returns True for two empty lists, two numeric values with the + same value, etc. + + Returns False for (pd.DataFrame, None), and other pairs which + should not be considered equivalent. + + Args: + a: value one of the comparison. + b: value two of the comparison. + + Returns: + A boolean indicating whether the two inputs are the same value + for the purposes of deprecation. + """ + if a is b: + return True + try: + equality = a == b + if isinstance(equality, bool): + return equality + except TypeError: + return False + return False + @functools.wraps(func) def new_func(*args, **kwargs): """Deprecation wrapper.""" @@ -232,7 +259,7 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples): for arg_name, spec in iter(deprecated_positions.items()): if (spec.position < len(args) and not (spec.has_ok_value and - named_args[arg_name] == spec.ok_value)): + _same_value(named_args[arg_name], spec.ok_value))): invalid_args.append(arg_name) if is_varargs_deprecated and len(args) > len(arg_spec.args): invalid_args.append(arg_spec.varargs) @@ -241,8 +268,8 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples): for arg_name in deprecated_arg_names: if (arg_name in kwargs and not (deprecated_positions[arg_name].has_ok_value and - (named_args[arg_name] == - deprecated_positions[arg_name].ok_value))): + _same_value(named_args[arg_name], + deprecated_positions[arg_name].ok_value))): invalid_args.append(arg_name) for arg_name in invalid_args: logging.warning( |