diff options
author | 2016-12-15 10:57:47 -0800 | |
---|---|---|
committer | 2016-12-15 11:05:27 -0800 | |
commit | 5462618b916cad9aec46a9d1fd04ef353e9696b9 (patch) | |
tree | 32225e17aa28db659d679098d7929c3d1034fca0 | |
parent | b8e47b00507d2e9783821c13063e4d94c5cd9809 (diff) |
Mutes warnings for some deprecated arguments with default values. Fixes a bug.
Users were getting log warnings even when using the API correctly,
because of argument passthrough within tf.learn. This change mutes
warnings for calls to fit, estimate, and predict if the deprecated
arguments receive default values.
Also fixes a bug in the deprecation tool where "==" was used instead of "is" to compare objects, which was at times yeilding a non-boolean.
Change: 142162193
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/estimator.py | 13 | ||||
-rw-r--r-- | tensorflow/python/util/deprecation.py | 33 |
2 files changed, 38 insertions, 8 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 91d900395b..2d40caa656 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -341,7 +341,8 @@ class BaseEstimator( return copy.deepcopy(self._config) @deprecated_args( - SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y', 'batch_size' + SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), + ('y', None), ('batch_size', None) ) def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None, monitors=None, max_steps=None): @@ -367,7 +368,8 @@ class BaseEstimator( return self @deprecated_args( - SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y', 'batch_size' + SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), + ('y', None), ('batch_size', None) ) def partial_fit( self, x=None, y=None, input_fn=None, steps=1, batch_size=None, @@ -411,7 +413,8 @@ class BaseEstimator( batch_size=batch_size, monitors=monitors) @deprecated_args( - SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y', 'batch_size' + SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), + ('y', None), ('batch_size', None) ) def evaluate( self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None, @@ -442,8 +445,8 @@ class BaseEstimator( return eval_results @deprecated_args( - SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'batch_size', - 'as_iterable' + SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), + ('batch_size', None), ('as_iterable', True) ) def predict( self, x=None, input_fn=None, batch_size=None, outputs=None, diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py index 32d2678ce7..c4afed649e 100644 --- a/tensorflow/python/util/deprecation.py +++ b/tensorflow/python/util/deprecation.py @@ -224,6 +224,33 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples): 'in the function signature: %s. ' 'Found next arguments: %s.' % (missing_args, known_args)) + def _same_value(a, b): + """A comparison operation that works for multiple object types. + + Returns True for two empty lists, two numeric values with the + same value, etc. + + Returns False for (pd.DataFrame, None), and other pairs which + should not be considered equivalent. + + Args: + a: value one of the comparison. + b: value two of the comparison. + + Returns: + A boolean indicating whether the two inputs are the same value + for the purposes of deprecation. + """ + if a is b: + return True + try: + equality = a == b + if isinstance(equality, bool): + return equality + except TypeError: + return False + return False + @functools.wraps(func) def new_func(*args, **kwargs): """Deprecation wrapper.""" @@ -232,7 +259,7 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples): for arg_name, spec in iter(deprecated_positions.items()): if (spec.position < len(args) and not (spec.has_ok_value and - named_args[arg_name] == spec.ok_value)): + _same_value(named_args[arg_name], spec.ok_value))): invalid_args.append(arg_name) if is_varargs_deprecated and len(args) > len(arg_spec.args): invalid_args.append(arg_spec.varargs) @@ -241,8 +268,8 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples): for arg_name in deprecated_arg_names: if (arg_name in kwargs and not (deprecated_positions[arg_name].has_ok_value and - (named_args[arg_name] == - deprecated_positions[arg_name].ok_value))): + _same_value(named_args[arg_name], + deprecated_positions[arg_name].ok_value))): invalid_args.append(arg_name) for arg_name in invalid_args: logging.warning( |