aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py13
-rw-r--r--tensorflow/python/util/deprecation.py33
2 files changed, 38 insertions, 8 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 91d900395b..2d40caa656 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -341,7 +341,8 @@ class BaseEstimator(
return copy.deepcopy(self._config)
@deprecated_args(
- SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y', 'batch_size'
+ SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None),
+ ('y', None), ('batch_size', None)
)
def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None,
monitors=None, max_steps=None):
@@ -367,7 +368,8 @@ class BaseEstimator(
return self
@deprecated_args(
- SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y', 'batch_size'
+ SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None),
+ ('y', None), ('batch_size', None)
)
def partial_fit(
self, x=None, y=None, input_fn=None, steps=1, batch_size=None,
@@ -411,7 +413,8 @@ class BaseEstimator(
batch_size=batch_size, monitors=monitors)
@deprecated_args(
- SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'y', 'batch_size'
+ SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None),
+ ('y', None), ('batch_size', None)
)
def evaluate(
self, x=None, y=None, input_fn=None, feed_fn=None, batch_size=None,
@@ -442,8 +445,8 @@ class BaseEstimator(
return eval_results
@deprecated_args(
- SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 'x', 'batch_size',
- 'as_iterable'
+ SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None),
+ ('batch_size', None), ('as_iterable', True)
)
def predict(
self, x=None, input_fn=None, batch_size=None, outputs=None,
diff --git a/tensorflow/python/util/deprecation.py b/tensorflow/python/util/deprecation.py
index 32d2678ce7..c4afed649e 100644
--- a/tensorflow/python/util/deprecation.py
+++ b/tensorflow/python/util/deprecation.py
@@ -224,6 +224,33 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples):
'in the function signature: %s. '
'Found next arguments: %s.' % (missing_args, known_args))
+ def _same_value(a, b):
+ """A comparison operation that works for multiple object types.
+
+ Returns True for two empty lists, two numeric values with the
+ same value, etc.
+
+ Returns False for (pd.DataFrame, None), and other pairs which
+ should not be considered equivalent.
+
+ Args:
+ a: value one of the comparison.
+ b: value two of the comparison.
+
+ Returns:
+ A boolean indicating whether the two inputs are the same value
+ for the purposes of deprecation.
+ """
+ if a is b:
+ return True
+ try:
+ equality = a == b
+ if isinstance(equality, bool):
+ return equality
+ except TypeError:
+ return False
+ return False
+
@functools.wraps(func)
def new_func(*args, **kwargs):
"""Deprecation wrapper."""
@@ -232,7 +259,7 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples):
for arg_name, spec in iter(deprecated_positions.items()):
if (spec.position < len(args) and
not (spec.has_ok_value and
- named_args[arg_name] == spec.ok_value)):
+ _same_value(named_args[arg_name], spec.ok_value))):
invalid_args.append(arg_name)
if is_varargs_deprecated and len(args) > len(arg_spec.args):
invalid_args.append(arg_spec.varargs)
@@ -241,8 +268,8 @@ def deprecated_args(date, instructions, *deprecated_arg_names_or_tuples):
for arg_name in deprecated_arg_names:
if (arg_name in kwargs and
not (deprecated_positions[arg_name].has_ok_value and
- (named_args[arg_name] ==
- deprecated_positions[arg_name].ok_value))):
+ _same_value(named_args[arg_name],
+ deprecated_positions[arg_name].ok_value))):
invalid_args.append(arg_name)
for arg_name in invalid_args:
logging.warning(