aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-05-11 10:58:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-11 11:01:30 -0700
commit9c82788d12037fc10b60b06092e94d513eb4aa14 (patch)
treee578a9cb80ee6e6a6fb6ff1e19cfc8c7dc0f522e
parent1aa40a1ce7869b6557049bcc623dad452a69ef6c (diff)
Move fn_args utility into core TensorFlow from Estimator.
Working on untangling TF/Estimator deps. Some core TF code depends on Estimator by using the fn_args utility function within Estimator. PiperOrigin-RevId: 196277612
-rw-r--r--tensorflow/contrib/eager/python/network.py6
-rw-r--r--tensorflow/contrib/estimator/BUILD2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/extenders.py6
-rw-r--r--tensorflow/contrib/estimator/python/estimator/logit_fns.py4
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/experiment.py4
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py8
-rw-r--r--tensorflow/python/BUILD10
-rw-r--r--tensorflow/python/estimator/BUILD12
-rw-r--r--tensorflow/python/estimator/canned/head.py6
-rw-r--r--tensorflow/python/estimator/estimator.py8
-rw-r--r--tensorflow/python/estimator/estimator_test.py6
-rw-r--r--tensorflow/python/estimator/run_config.py4
-rw-r--r--tensorflow/python/estimator/util.py40
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/base_layer.py7
-rw-r--r--tensorflow/python/layers/base.py4
-rw-r--r--tensorflow/python/ops/variable_scope.py4
-rw-r--r--tensorflow/python/training/monitored_session.py4
-rw-r--r--tensorflow/python/util/function_utils.py57
-rw-r--r--tensorflow/python/util/function_utils_test.py (renamed from tensorflow/python/estimator/util_test.py)18
20 files changed, 119 insertions, 95 deletions
diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py
index 44828bea50..9af50ee146 100644
--- a/tensorflow/contrib/eager/python/network.py
+++ b/tensorflow/contrib/eager/python/network.py
@@ -23,7 +23,6 @@ import os
import weakref
from tensorflow.python.eager import context
-from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import ops
from tensorflow.python.keras._impl.keras.engine import base_layer as keras_base_layer
from tensorflow.python.layers import base
@@ -33,6 +32,7 @@ from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
from tensorflow.python.util import deprecation
+from tensorflow.python.util import function_utils
# pylint: disable=protected-access
# Explanation for protected-access disable: Network has lots of same-class and
@@ -545,10 +545,10 @@ class Sequential(Network):
def add(self, layer_func):
if isinstance(layer_func, base.Layer):
- args = estimator_util.fn_args(layer_func.call)
+ args = function_utils.fn_args(layer_func.call)
self.track_layer(layer_func)
elif callable(layer_func):
- args = estimator_util.fn_args(layer_func)
+ args = function_utils.fn_args(layer_func)
else:
raise TypeError(
"Sequential.add() takes only tf.layers.Layer objects or callables; "
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 53bbafd4a7..df08dc2be6 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -366,9 +366,9 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:util",
"//tensorflow/python/estimator:dnn",
"//tensorflow/python/estimator:linear",
- "//tensorflow/python/estimator:util",
],
)
diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py
index 201699ed77..bf08be09e7 100644
--- a/tensorflow/contrib/estimator/python/estimator/extenders.py
+++ b/tensorflow/contrib/estimator/python/estimator/extenders.py
@@ -22,12 +22,12 @@ import six
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.estimator.export.export_output import PredictOutput
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import clip_ops
from tensorflow.python.training import optimizer as optimizer_lib
+from tensorflow.python.util import function_utils
_VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config'])
@@ -330,7 +330,7 @@ class _TransformGradients(optimizer_lib.Optimizer):
def _verify_metric_fn_args(metric_fn):
- args = set(estimator_util.fn_args(metric_fn))
+ args = set(function_utils.fn_args(metric_fn))
invalid_args = list(args - _VALID_METRIC_FN_ARGS)
if invalid_args:
raise ValueError('metric_fn (%s) has following not expected args: %s' %
@@ -339,7 +339,7 @@ def _verify_metric_fn_args(metric_fn):
def _call_metric_fn(metric_fn, features, labels, predictions, config):
"""Calls metric fn with proper arguments."""
- metric_fn_args = estimator_util.fn_args(metric_fn)
+ metric_fn_args = function_utils.fn_args(metric_fn)
kwargs = {}
if 'features' in metric_fn_args:
kwargs['features'] = features
diff --git a/tensorflow/contrib/estimator/python/estimator/logit_fns.py b/tensorflow/contrib/estimator/python/estimator/logit_fns.py
index 09c2862ccd..c8b0dd6297 100644
--- a/tensorflow/contrib/estimator/python/estimator/logit_fns.py
+++ b/tensorflow/contrib/estimator/python/estimator/logit_fns.py
@@ -41,10 +41,10 @@ from __future__ import print_function
import six
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.canned import dnn as dnn_core
from tensorflow.python.estimator.canned import linear as linear_core
from tensorflow.python.framework import ops
+from tensorflow.python.util import function_utils
# pylint: disable=protected-access
dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder
@@ -72,7 +72,7 @@ def call_logit_fn(logit_fn, features, mode, params, config):
ValueError: if logit_fn does not return a Tensor or a dictionary mapping
strings to Tensors.
"""
- logit_fn_args = util.fn_args(logit_fn)
+ logit_fn_args = function_utils.fn_args(logit_fn)
kwargs = {}
if 'mode' in logit_fn_args:
kwargs['mode'] = mode
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
index f8564446e5..cda23aa437 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
@@ -32,7 +32,6 @@ import six
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.client import device_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import device as framework_device
from tensorflow.python.framework import ops as ops_lib
@@ -48,6 +47,7 @@ from tensorflow.python.platform import tf_logging
from tensorflow.python.training import device_setter as device_setter_lib
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.util import deprecation
+from tensorflow.python.util import function_utils
@deprecation.deprecated(
@@ -521,7 +521,7 @@ def _get_loss_towers(model_fn,
"""Replicate the loss computation across devices."""
tower_specs = []
- model_fn_args = util.fn_args(model_fn)
+ model_fn_args = function_utils.fn_args(model_fn)
optional_params = {}
if 'params' in model_fn_args:
optional_params['params'] = copy.deepcopy(params)
diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py
index dfc6a393d0..541da90617 100644
--- a/tensorflow/contrib/learn/python/learn/experiment.py
+++ b/tensorflow/contrib/learn/python/learn/experiment.py
@@ -38,19 +38,19 @@ from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.python.estimator import estimator as core_estimator
-from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
+from tensorflow.python.util import function_utils
__all__ = ["Experiment"]
def _get_standardized_predicate_fn(predicate_fn):
- pred_fn_args = estimator_util.fn_args(predicate_fn)
+ pred_fn_args = function_utils.fn_args(predicate_fn)
if "checkpoint_path" not in pred_fn_args:
# pylint: disable=unused-argument
def _pred_fn_wrapper(eval_results, checkpoint_path):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index afc8c7d5cc..1bf2fc5dea 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -46,7 +46,6 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator import util
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -68,6 +67,7 @@ from tensorflow.python.training import evaluation
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.training import training_util
+from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
@@ -1269,7 +1269,7 @@ class _ModelFnWrapper(object):
def _call_model_fn(self, features, labels, is_export_mode=False):
"""Calls the model_fn with required parameters."""
- model_fn_args = util.fn_args(self._model_fn)
+ model_fn_args = function_utils.fn_args(self._model_fn)
kwargs = {}
# Makes deep copy with `config` and params` in case user mutates them.
@@ -1361,7 +1361,7 @@ class _OutfeedHostCall(object):
if isinstance(host_call[1], (tuple, list)):
fullargspec = tf_inspect.getfullargspec(host_call[0])
- fn_args = util.fn_args(host_call[0])
+ fn_args = function_utils.fn_args(host_call[0])
# wrapped_hostcall_with_global_step uses varargs, so we allow that.
if fullargspec.varargs is None and len(host_call[1]) != len(fn_args):
raise RuntimeError(
@@ -1938,7 +1938,7 @@ class TPUEstimator(estimator_lib.Estimator):
Raises:
ValueError: if input_fn takes invalid arguments or does not have `params`.
"""
- input_fn_args = util.fn_args(input_fn)
+ input_fn_args = function_utils.fn_args(input_fn)
config = self.config # a deep copy.
kwargs = {}
if 'params' in input_fn_args:
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 8b904a16c7..cc96d5aee5 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3250,6 +3250,16 @@ py_test(
)
py_test(
+ name = "function_utils_test",
+ srcs = ["util/function_utils_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":client_testlib",
+ ":util",
+ ],
+)
+
+py_test(
name = "tf_contextlib_test",
size = "small",
srcs = ["util/tf_contextlib_test.py"],
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 2d9a084bc6..a498e85572 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -445,16 +445,6 @@ py_library(
],
)
-py_test(
- name = "util_test",
- srcs = ["util_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":util",
- "//tensorflow/python:client_testlib",
- ],
-)
-
py_library(
name = "estimator",
srcs = [
@@ -645,7 +635,6 @@ py_library(
":metric_keys",
":model_fn",
":prediction_keys",
- ":util",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:control_flow_ops",
@@ -659,6 +648,7 @@ py_library(
"//tensorflow/python:string_ops",
"//tensorflow/python:summary",
"//tensorflow/python:training",
+ "//tensorflow/python:util",
"//tensorflow/python:weights_broadcast_ops",
"//tensorflow/python/feature_column",
"//tensorflow/python/ops/losses",
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index 232637314d..dcf8b15dad 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -24,7 +24,6 @@ import collections
import six
from tensorflow.python.estimator import model_fn
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export_output
@@ -46,6 +45,7 @@ from tensorflow.python.ops.losses import losses
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.summary import summary
from tensorflow.python.training import training_util
+from tensorflow.python.util import function_utils
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -461,7 +461,7 @@ def _validate_loss_fn_args(loss_fn):
Raises:
ValueError: If the signature is unexpected.
"""
- loss_fn_args = util.fn_args(loss_fn)
+ loss_fn_args = function_utils.fn_args(loss_fn)
for required_arg in ['labels', 'logits']:
if required_arg not in loss_fn_args:
raise ValueError(
@@ -484,7 +484,7 @@ def _call_loss_fn(loss_fn, labels, logits, features, expected_loss_dim=1):
Returns:
Loss Tensor with shape [D0, D1, ... DN, expected_loss_dim].
"""
- loss_fn_args = util.fn_args(loss_fn)
+ loss_fn_args = function_utils.fn_args(loss_fn)
kwargs = {}
if 'features' in loss_fn_args:
kwargs['features'] = features
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 9cfc680789..5fdda0427f 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -36,7 +36,6 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.export import export as export_helpers
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import errors
@@ -63,6 +62,7 @@ from tensorflow.python.training import training_util
from tensorflow.python.training import warm_starting_util
from tensorflow.python.util import compat
from tensorflow.python.util import compat_internal
+from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -1052,7 +1052,7 @@ class Estimator(object):
Raises:
ValueError: if input_fn takes invalid arguments.
"""
- input_fn_args = util.fn_args(input_fn)
+ input_fn_args = function_utils.fn_args(input_fn)
kwargs = {}
if 'mode' in input_fn_args:
kwargs['mode'] = mode
@@ -1078,7 +1078,7 @@ class Estimator(object):
Raises:
ValueError: if model_fn returns invalid objects.
"""
- model_fn_args = util.fn_args(self._model_fn)
+ model_fn_args = function_utils.fn_args(self._model_fn)
kwargs = {}
if 'labels' in model_fn_args:
kwargs['labels'] = labels
@@ -1483,7 +1483,7 @@ def _get_replica_device_setter(config):
def _verify_model_fn_args(model_fn, params):
"""Verifies model fn arguments."""
- args = set(util.fn_args(model_fn))
+ args = set(function_utils.fn_args(model_fn))
if 'features' not in args:
raise ValueError('model_fn (%s) must include features argument.' % model_fn)
if params is not None and 'params' not in args:
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 0f268f5df9..1b70189948 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -33,7 +33,6 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
-from tensorflow.python.estimator import util
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.export import export_output
from tensorflow.python.estimator.inputs import numpy_io
@@ -72,6 +71,7 @@ from tensorflow.python.training import saver_test_utils
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.util import compat
+from tensorflow.python.util import function_utils
_TMP_DIR = '/tmp'
_ANOTHER_TMP_DIR = '/another_tmp'
@@ -332,7 +332,7 @@ class EstimatorConstructorTest(test.TestCase):
_, _, _, _, _ = features, labels, mode, config, params
est = estimator.Estimator(model_fn=model_fn)
- model_fn_args = util.fn_args(est.model_fn)
+ model_fn_args = function_utils.fn_args(est.model_fn)
self.assertEqual(
set(['features', 'labels', 'mode', 'config']), set(model_fn_args))
@@ -342,7 +342,7 @@ class EstimatorConstructorTest(test.TestCase):
_, _ = features, labels
est = estimator.Estimator(model_fn=model_fn)
- model_fn_args = util.fn_args(est.model_fn)
+ model_fn_args = function_utils.fn_args(est.model_fn)
self.assertEqual(
set(['features', 'labels', 'mode', 'config']), set(model_fn_args))
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index 8162b249f1..c7707be839 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -27,8 +27,8 @@ import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
-from tensorflow.python.estimator import util
from tensorflow.python.util import compat_internal
+from tensorflow.python.util import function_utils
from tensorflow.python.util.tf_export import tf_export
@@ -283,7 +283,7 @@ def _validate_properties(run_config):
message='tf_random_seed must be integer.')
_validate('device_fn', lambda device_fn: six.callable(device_fn) and
- set(util.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
+ set(function_utils.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
message='device_fn must be callable with exactly'
' one argument "op".')
diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py
index bb4bdd3fdf..e4e1d37f74 100644
--- a/tensorflow/python/estimator/util.py
+++ b/tensorflow/python/estimator/util.py
@@ -13,55 +13,21 @@
# limitations under the License.
# ==============================================================================
-"""Utility to retrieve function args."""
+"""Utilities for Estimators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import functools
import os
import time
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
-from tensorflow.python.util import tf_decorator
-from tensorflow.python.util import tf_inspect
-
-
-def _is_bounded_method(fn):
- _, fn = tf_decorator.unwrap(fn)
- return tf_inspect.ismethod(fn) and (fn.__self__ is not None)
-
-
-def _is_callable_object(obj):
- return hasattr(obj, '__call__') and tf_inspect.ismethod(obj.__call__)
-
-
-def fn_args(fn):
- """Get argument names for function-like object.
-
- Args:
- fn: Function, or function-like object (e.g., result of `functools.partial`).
-
- Returns:
- `tuple` of string argument names.
-
- Raises:
- ValueError: if partial function has positionally bound arguments
- """
- if isinstance(fn, functools.partial):
- args = fn_args(fn.func)
- args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
- else:
- if _is_callable_object(fn):
- fn = fn.__call__
- args = tf_inspect.getfullargspec(fn).args
- if _is_bounded_method(fn):
- args.remove('self')
- return tuple(args)
+from tensorflow.python.util import function_utils
+fn_args = function_utils.fn_args
# When we create a timestamped directory, there is a small chance that the
# directory already exists because another process is also creating these
diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
index 16ee2952b2..72ab77fbbd 100644
--- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
@@ -25,7 +25,7 @@ import numpy as np
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
-from tensorflow.python.estimator import util as estimator_util
+from tensorflow.python.estimator import util as function_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -44,6 +44,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training import checkpointable
+from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
@@ -146,7 +147,7 @@ class Layer(checkpointable.CheckpointableBase):
# return tensors. When using graph execution, _losses is a list of ops.
self._losses = []
self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
- self._call_fn_args = estimator_util.fn_args(self.call)
+ self._call_fn_args = function_utils.fn_args(self.call)
self._compute_previous_mask = ('mask' in self._call_fn_args or
hasattr(self, 'compute_mask'))
self._uses_inputs_arg = True
@@ -644,7 +645,7 @@ class Layer(checkpointable.CheckpointableBase):
self._compute_previous_mask):
previous_mask = collect_previous_mask(inputs)
if not hasattr(self, '_call_fn_args'):
- self._call_fn_args = estimator_util.fn_args(self.call)
+ self._call_fn_args = function_utils.fn_args(self.call)
if ('mask' in self._call_fn_args and 'mask' not in kwargs and
not generic_utils.is_all_none(previous_mask)):
# The previous layer generated a mask, and mask was not explicitly pass
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 64db49c900..2040e0081e 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -20,12 +20,12 @@ from __future__ import print_function
import copy
from tensorflow.python.eager import context
-from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras._impl.keras.engine import base_layer
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
+from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -308,7 +308,7 @@ class Layer(base_layer.Layer):
try:
call_has_scope_arg = self._call_has_scope_arg
except AttributeError:
- self._call_fn_args = estimator_util.fn_args(self.call)
+ self._call_fn_args = function_utils.fn_args(self.call)
self._call_has_scope_arg = 'scope' in self._call_fn_args
call_has_scope_arg = self._call_has_scope_arg
if call_has_scope_arg:
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index adb0f59948..f5970fdbb2 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -32,7 +32,6 @@ from six import iteritems
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
-from tensorflow.python.estimator import util as estimator_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -41,6 +40,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
@@ -422,7 +422,7 @@ class _VariableStore(object):
"use_resource": use_resource,
}
# `fn_args` can handle functions, `functools.partial`, `lambda`.
- if "constraint" in estimator_util.fn_args(custom_getter):
+ if "constraint" in function_utils.fn_args(custom_getter):
custom_getter_kwargs["constraint"] = constraint
return custom_getter(**custom_getter_kwargs)
else:
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index f584a009d9..fece3370f3 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -25,7 +25,6 @@ import sys
import six
from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.estimator import util
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -41,6 +40,7 @@ from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as training_saver
from tensorflow.python.training import session_manager as sm
from tensorflow.python.training import session_run_hook
+from tensorflow.python.util import function_utils
from tensorflow.python.util.tf_export import tf_export
@@ -620,7 +620,7 @@ class _MonitoredSession(object):
`step_context`. It may also optionally have `self` for cases when it
belongs to an object.
"""
- step_fn_arguments = util.fn_args(step_fn)
+ step_fn_arguments = function_utils.fn_args(step_fn)
if step_fn_arguments != ('step_context',) and step_fn_arguments != (
'self',
'step_context',
diff --git a/tensorflow/python/util/function_utils.py b/tensorflow/python/util/function_utils.py
new file mode 100644
index 0000000000..7bbbde3cd2
--- /dev/null
+++ b/tensorflow/python/util/function_utils.py
@@ -0,0 +1,57 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility to retrieve function args."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
+
+
+def _is_bounded_method(fn):
+ _, fn = tf_decorator.unwrap(fn)
+ return tf_inspect.ismethod(fn) and (fn.__self__ is not None)
+
+
+def _is_callable_object(obj):
+ return hasattr(obj, '__call__') and tf_inspect.ismethod(obj.__call__)
+
+
+def fn_args(fn):
+ """Get argument names for function-like object.
+
+ Args:
+ fn: Function, or function-like object (e.g., result of `functools.partial`).
+
+ Returns:
+ `tuple` of string argument names.
+
+ Raises:
+ ValueError: if partial function has positionally bound arguments
+ """
+ if isinstance(fn, functools.partial):
+ args = fn_args(fn.func)
+ args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
+ else:
+ if _is_callable_object(fn):
+ fn = fn.__call__
+ args = tf_inspect.getfullargspec(fn).args
+ if _is_bounded_method(fn):
+ args.remove('self')
+ return tuple(args)
diff --git a/tensorflow/python/estimator/util_test.py b/tensorflow/python/util/function_utils_test.py
index 4b2c8d7637..e78cf6a5b0 100644
--- a/tensorflow/python/estimator/util_test.py
+++ b/tensorflow/python/util/function_utils_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import functools
-from tensorflow.python.estimator import util
from tensorflow.python.platform import test
+from tensorflow.python.util import function_utils
class FnArgsTest(test.TestCase):
@@ -29,7 +29,7 @@ class FnArgsTest(test.TestCase):
def test_simple_function(self):
def fn(a, b):
return a + b
- self.assertEqual(('a', 'b'), util.fn_args(fn))
+ self.assertEqual(('a', 'b'), function_utils.fn_args(fn))
def test_callable(self):
@@ -38,7 +38,7 @@ class FnArgsTest(test.TestCase):
def __call__(self, a, b):
return a + b
- self.assertEqual(('a', 'b'), util.fn_args(Foo()))
+ self.assertEqual(('a', 'b'), function_utils.fn_args(Foo()))
def test_bounded_method(self):
@@ -47,7 +47,7 @@ class FnArgsTest(test.TestCase):
def bar(self, a, b):
return a + b
- self.assertEqual(('a', 'b'), util.fn_args(Foo().bar))
+ self.assertEqual(('a', 'b'), function_utils.fn_args(Foo().bar))
def test_partial_function(self):
expected_test_arg = 123
@@ -59,7 +59,7 @@ class FnArgsTest(test.TestCase):
wrapped_fn = functools.partial(fn, test_arg=123)
- self.assertEqual(('a',), util.fn_args(wrapped_fn))
+ self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))
def test_partial_function_with_positional_args(self):
expected_test_arg = 123
@@ -71,7 +71,7 @@ class FnArgsTest(test.TestCase):
wrapped_fn = functools.partial(fn, 123)
- self.assertEqual(('a',), util.fn_args(wrapped_fn))
+ self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))
self.assertEqual(3, wrapped_fn(3))
self.assertEqual(3, wrapped_fn(a=3))
@@ -88,7 +88,7 @@ class FnArgsTest(test.TestCase):
wrapped_fn = functools.partial(fn, test_arg2=456)
double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)
- self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
+ self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
def test_double_partial_with_positional_args_in_outer_layer(self):
expected_test_arg1 = 123
@@ -102,7 +102,7 @@ class FnArgsTest(test.TestCase):
wrapped_fn = functools.partial(fn, test_arg2=456)
double_wrapped_fn = functools.partial(wrapped_fn, 123)
- self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
+ self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
self.assertEqual(3, double_wrapped_fn(3))
self.assertEqual(3, double_wrapped_fn(a=3))
@@ -119,7 +119,7 @@ class FnArgsTest(test.TestCase):
wrapped_fn = functools.partial(fn, 123) # binds to test_arg1
double_wrapped_fn = functools.partial(wrapped_fn, 456) # binds to test_arg2
- self.assertEqual(('a',), util.fn_args(double_wrapped_fn))
+ self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
self.assertEqual(3, double_wrapped_fn(3))
self.assertEqual(3, double_wrapped_fn(a=3))