aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-15 10:57:47 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-15 11:05:27 -0800
commit5462618b916cad9aec46a9d1fd04ef353e9696b9 (patch)
tree32225e17aa28db659d679098d7929c3d1034fca0
parentb8e47b00507d2e9783821c13063e4d94c5cd9809 (diff)
Mutes warnings for some deprecated arguments with default values. Fixes a bug.
Users were getting log warnings even when using the API correctly, because of argument passthrough within tf.learn. This change mutes warnings for calls to fit, estimate, and predict if the deprecated arguments receive default values. Also fixes a bug in the deprecation tool where "==" was used instead of "is" to compare objects, which was at times yeilding a non-boolean. Change: 142162193
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py13
-rw-r--r--tensorflow/python/util/deprecation.py33
2 files changed, 38 insertions, 8 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 91d900395b..2d40caa656 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -341,7 +341,8 @@ class BaseEstimator(
return copy.deepcopy(self._config)
@deprecated_args(
- SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y', 'batch_size'
+ SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None),
+ ('y', None), ('batch_size', None)
)
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
monitors=None, max_steps=None):
@@ -367,7 +368,8 @@ class BaseEstimator(
return self
@deprecated_args(
- SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y', 'batch_size'
+ SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None),
+ ('y', None), ('batch_size', None)
)
def partial_fit(
self, x=None, y=None, input_fn=None, steps=1, batch_size=None,
@@ -411,7 +413,8 @@ class BaseEstimator(
batch_size=batch_size, monitors=monitors)
@deprecated_args(
- SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y', 'batch_size'
+ SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None),
+ ('y', None), ('batch_size', None)
)
def evaluate(
self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None,
@@ -442,8 +445,8 @@ class BaseEstimator(
return eval_results
@deprecated_args(
- SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'batch_size',
- 'as_iterable'
+ SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None),
+ ('batch_size', None), ('as_iterable', True)
)
def predict(
self, x=None, input_fn=None, batch_size=None, outputs=None,
diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py
index 32d2678ce7..c4afed649e 100644
--- a/tensorflow/python/util/deprecation.py
+++ b/tensorflow/python/util/deprecation.py
@@ -224,6 +224,33 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples):
'in the function signature: %s. '
'Found next arguments: %s.' % (missing_args, known_args))
+ def _same_value(a, b):
+ """A comparison operation that works for multiple object types.
+
+ Returns True for two empty lists, two numeric values with the
+ same value, etc.
+
+ Returns False for (pd.DataFrame, None), and other pairs which
+ should not be considered equivalent.
+
+ Args:
+ a: value one of the comparison.
+ b: value two of the comparison.
+
+ Returns:
+ A boolean indicating whether the two inputs are the same value
+ for the purposes of deprecation.
+ """
+ if a is b:
+ return True
+ try:
+ equality = a == b
+ if isinstance(equality, bool):
+ return equality
+ except TypeError:
+ return False
+ return False
+
@functools.wraps(func)
def new_func(*args, **kwargs):
"""Deprecation wrapper."""
@@ -232,7 +259,7 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples):
for arg_name, spec in iter(deprecated_positions.items()):
if (spec.position < len(args) and
not (spec.has_ok_value and
- named_args[arg_name] == spec.ok_value)):
+ _same_value(named_args[arg_name], spec.ok_value))):
invalid_args.append(arg_name)
if is_varargs_deprecated and len(args) > len(arg_spec.args):
invalid_args.append(arg_spec.varargs)
@@ -241,8 +268,8 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples):
for arg_name in deprecated_arg_names:
if (arg_name in kwargs and
not (deprecated_positions[arg_name].has_ok_value and
- (named_args[arg_name] ==
- deprecated_positions[arg_name].ok_value))):
+ _same_value(named_args[arg_name],
+ deprecated_positions[arg_name].ok_value))):
invalid_args.append(arg_name)
for arg_name in invalid_args:
logging.warning(