From 9c82788d12037fc10b60b06092e94d513eb4aa14 Mon Sep 17 00:00:00 2001 From: Michael Case Date: Fri, 11 May 2018 10:58:17 -0700 Subject: Move fn_args utility into core TensorFlow from Estimator. Working on untangling TF/Estimator deps. Some core TF code depends on Estimator by using the fn_args utility function within Estimator. PiperOrigin-RevId: 196277612 --- tensorflow/contrib/eager/python/network.py | 6 +- tensorflow/contrib/estimator/BUILD | 2 +- .../estimator/python/estimator/extenders.py | 6 +- .../estimator/python/estimator/logit_fns.py | 4 +- .../python/estimator/replicate_model_fn.py | 4 +- .../contrib/learn/python/learn/experiment.py | 4 +- tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 8 +- tensorflow/python/BUILD | 10 ++ tensorflow/python/estimator/BUILD | 12 +- tensorflow/python/estimator/canned/head.py | 6 +- tensorflow/python/estimator/estimator.py | 8 +- tensorflow/python/estimator/estimator_test.py | 6 +- tensorflow/python/estimator/run_config.py | 4 +- tensorflow/python/estimator/util.py | 40 +------ tensorflow/python/estimator/util_test.py | 128 --------------------- .../python/keras/_impl/keras/engine/base_layer.py | 7 +- tensorflow/python/layers/base.py | 4 +- tensorflow/python/ops/variable_scope.py | 4 +- tensorflow/python/training/monitored_session.py | 4 +- tensorflow/python/util/function_utils.py | 57 +++++++++ tensorflow/python/util/function_utils_test.py | 128 +++++++++++++++++++++ 21 files changed, 238 insertions(+), 214 deletions(-) delete mode 100644 tensorflow/python/estimator/util_test.py create mode 100644 tensorflow/python/util/function_utils.py create mode 100644 tensorflow/python/util/function_utils_test.py diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 44828bea50..9af50ee146 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -23,7 +23,6 @@ import os import weakref from tensorflow.python.eager import context -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops from tensorflow.python.keras._impl.keras.engine import base_layer as keras_base_layer from tensorflow.python.layers import base @@ -33,6 +32,7 @@ from tensorflow.python.training import checkpoint_utils from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util from tensorflow.python.util import deprecation +from tensorflow.python.util import function_utils # pylint: disable=protected-access # Explanation for protected-access disable: Network has lots of same-class and @@ -545,10 +545,10 @@ class Sequential(Network): def add(self, layer_func): if isinstance(layer_func, base.Layer): - args = estimator_util.fn_args(layer_func.call) + args = function_utils.fn_args(layer_func.call) self.track_layer(layer_func) elif callable(layer_func): - args = estimator_util.fn_args(layer_func) + args = function_utils.fn_args(layer_func) else: raise TypeError( "Sequential.add() takes only tf.layers.Layer objects or callables; " diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 53bbafd4a7..df08dc2be6 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -366,9 +366,9 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:framework_ops", + "//tensorflow/python:util", "//tensorflow/python/estimator:dnn", "//tensorflow/python/estimator:linear", - "//tensorflow/python/estimator:util", ], ) diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py index 201699ed77..bf08be09e7 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -22,12 +22,12 @@ import six from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.estimator.export.export_output import PredictOutput from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.ops import clip_ops from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.util import function_utils _VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config']) @@ -330,7 +330,7 @@ class _TransformGradients(optimizer_lib.Optimizer): def _verify_metric_fn_args(metric_fn): - args = set(estimator_util.fn_args(metric_fn)) + args = set(function_utils.fn_args(metric_fn)) invalid_args = list(args - _VALID_METRIC_FN_ARGS) if invalid_args: raise ValueError('metric_fn (%s) has following not expected args: %s' % @@ -339,7 +339,7 @@ def _verify_metric_fn_args(metric_fn): def _call_metric_fn(metric_fn, features, labels, predictions, config): """Calls metric fn with proper arguments.""" - metric_fn_args = estimator_util.fn_args(metric_fn) + metric_fn_args = function_utils.fn_args(metric_fn) kwargs = {} if 'features' in metric_fn_args: kwargs['features'] = features diff --git a/tensorflow/contrib/estimator/python/estimator/logit_fns.py b/tensorflow/contrib/estimator/python/estimator/logit_fns.py index 09c2862ccd..c8b0dd6297 100644 --- a/tensorflow/contrib/estimator/python/estimator/logit_fns.py +++ b/tensorflow/contrib/estimator/python/estimator/logit_fns.py @@ -41,10 +41,10 @@ from __future__ import print_function import six -from tensorflow.python.estimator import util from tensorflow.python.estimator.canned import dnn as dnn_core from tensorflow.python.estimator.canned import linear as linear_core from tensorflow.python.framework import ops +from tensorflow.python.util import function_utils # pylint: disable=protected-access dnn_logit_fn_builder = dnn_core._dnn_logit_fn_builder @@ -72,7 +72,7 @@ def call_logit_fn(logit_fn, features, mode, params, config): ValueError: if logit_fn does not return a Tensor or a dictionary mapping strings to Tensors. """ - logit_fn_args = util.fn_args(logit_fn) + logit_fn_args = function_utils.fn_args(logit_fn) kwargs = {} if 'mode' in logit_fn_args: kwargs['mode'] = mode diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index f8564446e5..cda23aa437 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -32,7 +32,6 @@ import six from tensorflow.core.framework import node_def_pb2 from tensorflow.python.client import device_lib from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator import util from tensorflow.python.estimator.export import export_output as export_output_lib from tensorflow.python.framework import device as framework_device from tensorflow.python.framework import ops as ops_lib @@ -48,6 +47,7 @@ from tensorflow.python.platform import tf_logging from tensorflow.python.training import device_setter as device_setter_lib from tensorflow.python.training import optimizer as optimizer_lib from tensorflow.python.util import deprecation +from tensorflow.python.util import function_utils @deprecation.deprecated( @@ -521,7 +521,7 @@ def _get_loss_towers(model_fn, """Replicate the loss computation across devices.""" tower_specs = [] - model_fn_args = util.fn_args(model_fn) + model_fn_args = function_utils.fn_args(model_fn) optional_params = {} if 'params' in model_fn_args: optional_params['params'] = copy.deepcopy(params) diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index dfc6a393d0..541da90617 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -38,19 +38,19 @@ from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.tpu.python.tpu import tpu_estimator from tensorflow.python.estimator import estimator as core_estimator -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks from tensorflow.python.training import saver from tensorflow.python.training import server_lib from tensorflow.python.util import compat +from tensorflow.python.util import function_utils __all__ = ["Experiment"] def _get_standardized_predicate_fn(predicate_fn): - pred_fn_args = estimator_util.fn_args(predicate_fn) + pred_fn_args = function_utils.fn_args(predicate_fn) if "checkpoint_path" not in pred_fn_args: # pylint: disable=unused-argument def _pred_fn_wrapper(eval_results, checkpoint_path): diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index afc8c7d5cc..1bf2fc5dea 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -46,7 +46,6 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib -from tensorflow.python.estimator import util from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -68,6 +67,7 @@ from tensorflow.python.training import evaluation from tensorflow.python.training import session_run_hook from tensorflow.python.training import training from tensorflow.python.training import training_util +from tensorflow.python.util import function_utils from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect @@ -1269,7 +1269,7 @@ class _ModelFnWrapper(object): def _call_model_fn(self, features, labels, is_export_mode=False): """Calls the model_fn with required parameters.""" - model_fn_args = util.fn_args(self._model_fn) + model_fn_args = function_utils.fn_args(self._model_fn) kwargs = {} # Makes deep copy with `config` and params` in case user mutates them. @@ -1361,7 +1361,7 @@ class _OutfeedHostCall(object): if isinstance(host_call[1], (tuple, list)): fullargspec = tf_inspect.getfullargspec(host_call[0]) - fn_args = util.fn_args(host_call[0]) + fn_args = function_utils.fn_args(host_call[0]) # wrapped_hostcall_with_global_step uses varargs, so we allow that. if fullargspec.varargs is None and len(host_call[1]) != len(fn_args): raise RuntimeError( @@ -1938,7 +1938,7 @@ class TPUEstimator(estimator_lib.Estimator): Raises: ValueError: if input_fn takes invalid arguments or does not have `params`. """ - input_fn_args = util.fn_args(input_fn) + input_fn_args = function_utils.fn_args(input_fn) config = self.config # a deep copy. kwargs = {} if 'params' in input_fn_args: diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 8b904a16c7..cc96d5aee5 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3249,6 +3249,16 @@ py_test( ], ) +py_test( + name = "function_utils_test", + srcs = ["util/function_utils_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":util", + ], +) + py_test( name = "tf_contextlib_test", size = "small", diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 2d9a084bc6..a498e85572 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -445,16 +445,6 @@ py_library( ], ) -py_test( - name = "util_test", - srcs = ["util_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":util", - "//tensorflow/python:client_testlib", - ], -) - py_library( name = "estimator", srcs = [ @@ -645,7 +635,6 @@ py_library( ":metric_keys", ":model_fn", ":prediction_keys", - ":util", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:control_flow_ops", @@ -659,6 +648,7 @@ py_library( "//tensorflow/python:string_ops", "//tensorflow/python:summary", "//tensorflow/python:training", + "//tensorflow/python:util", "//tensorflow/python:weights_broadcast_ops", "//tensorflow/python/feature_column", "//tensorflow/python/ops/losses", diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index 232637314d..dcf8b15dad 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -24,7 +24,6 @@ import collections import six from tensorflow.python.estimator import model_fn -from tensorflow.python.estimator import util from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.canned import prediction_keys from tensorflow.python.estimator.export import export_output @@ -46,6 +45,7 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary from tensorflow.python.training import training_util +from tensorflow.python.util import function_utils _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY @@ -461,7 +461,7 @@ def _validate_loss_fn_args(loss_fn): Raises: ValueError: If the signature is unexpected. """ - loss_fn_args = util.fn_args(loss_fn) + loss_fn_args = function_utils.fn_args(loss_fn) for required_arg in ['labels', 'logits']: if required_arg not in loss_fn_args: raise ValueError( @@ -484,7 +484,7 @@ def _call_loss_fn(loss_fn, labels, logits, features, expected_loss_dim=1): Returns: Loss Tensor with shape [D0, D1, ... DN, expected_loss_dim]. """ - loss_fn_args = util.fn_args(loss_fn) + loss_fn_args = function_utils.fn_args(loss_fn) kwargs = {} if 'features' in loss_fn_args: kwargs['features'] = features diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 9cfc680789..5fdda0427f 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -36,7 +36,6 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config -from tensorflow.python.estimator import util from tensorflow.python.estimator.export import export as export_helpers from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import errors @@ -63,6 +62,7 @@ from tensorflow.python.training import training_util from tensorflow.python.training import warm_starting_util from tensorflow.python.util import compat from tensorflow.python.util import compat_internal +from tensorflow.python.util import function_utils from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -1052,7 +1052,7 @@ class Estimator(object): Raises: ValueError: if input_fn takes invalid arguments. """ - input_fn_args = util.fn_args(input_fn) + input_fn_args = function_utils.fn_args(input_fn) kwargs = {} if 'mode' in input_fn_args: kwargs['mode'] = mode @@ -1078,7 +1078,7 @@ class Estimator(object): Raises: ValueError: if model_fn returns invalid objects. """ - model_fn_args = util.fn_args(self._model_fn) + model_fn_args = function_utils.fn_args(self._model_fn) kwargs = {} if 'labels' in model_fn_args: kwargs['labels'] = labels @@ -1483,7 +1483,7 @@ def _get_replica_device_setter(config): def _verify_model_fn_args(model_fn, params): """Verifies model fn arguments.""" - args = set(util.fn_args(model_fn)) + args = set(function_utils.fn_args(model_fn)) if 'features' not in args: raise ValueError('model_fn (%s) must include features argument.' % model_fn) if params is not None and 'params' not in args: diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 0f268f5df9..1b70189948 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -33,7 +33,6 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import run_config -from tensorflow.python.estimator import util from tensorflow.python.estimator.export import export from tensorflow.python.estimator.export import export_output from tensorflow.python.estimator.inputs import numpy_io @@ -72,6 +71,7 @@ from tensorflow.python.training import saver_test_utils from tensorflow.python.training import session_run_hook from tensorflow.python.training import training from tensorflow.python.util import compat +from tensorflow.python.util import function_utils _TMP_DIR = '/tmp' _ANOTHER_TMP_DIR = '/another_tmp' @@ -332,7 +332,7 @@ class EstimatorConstructorTest(test.TestCase): _, _, _, _, _ = features, labels, mode, config, params est = estimator.Estimator(model_fn=model_fn) - model_fn_args = util.fn_args(est.model_fn) + model_fn_args = function_utils.fn_args(est.model_fn) self.assertEqual( set(['features', 'labels', 'mode', 'config']), set(model_fn_args)) @@ -342,7 +342,7 @@ class EstimatorConstructorTest(test.TestCase): _, _ = features, labels est = estimator.Estimator(model_fn=model_fn) - model_fn_args = util.fn_args(est.model_fn) + model_fn_args = function_utils.fn_args(est.model_fn) self.assertEqual( set(['features', 'labels', 'mode', 'config']), set(model_fn_args)) diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index 8162b249f1..c7707be839 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -27,8 +27,8 @@ import six from tensorflow.core.protobuf import config_pb2 from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib -from tensorflow.python.estimator import util from tensorflow.python.util import compat_internal +from tensorflow.python.util import function_utils from tensorflow.python.util.tf_export import tf_export @@ -283,7 +283,7 @@ def _validate_properties(run_config): message='tf_random_seed must be integer.') _validate('device_fn', lambda device_fn: six.callable(device_fn) and - set(util.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS, + set(function_utils.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS, message='device_fn must be callable with exactly' ' one argument "op".') diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py index bb4bdd3fdf..e4e1d37f74 100644 --- a/tensorflow/python/estimator/util.py +++ b/tensorflow/python/estimator/util.py @@ -13,55 +13,21 @@ # limitations under the License. # ============================================================================== -"""Utility to retrieve function args.""" +"""Utilities for Estimators.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import functools import os import time from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat -from tensorflow.python.util import tf_decorator -from tensorflow.python.util import tf_inspect - - -def _is_bounded_method(fn): - _, fn = tf_decorator.unwrap(fn) - return tf_inspect.ismethod(fn) and (fn.__self__ is not None) - - -def _is_callable_object(obj): - return hasattr(obj, '__call__') and tf_inspect.ismethod(obj.__call__) - - -def fn_args(fn): - """Get argument names for function-like object. - - Args: - fn: Function, or function-like object (e.g., result of `functools.partial`). - - Returns: - `tuple` of string argument names. - - Raises: - ValueError: if partial function has positionally bound arguments - """ - if isinstance(fn, functools.partial): - args = fn_args(fn.func) - args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])] - else: - if _is_callable_object(fn): - fn = fn.__call__ - args = tf_inspect.getfullargspec(fn).args - if _is_bounded_method(fn): - args.remove('self') - return tuple(args) +from tensorflow.python.util import function_utils +fn_args = function_utils.fn_args # When we create a timestamped directory, there is a small chance that the # directory already exists because another process is also creating these diff --git a/tensorflow/python/estimator/util_test.py b/tensorflow/python/estimator/util_test.py deleted file mode 100644 index 4b2c8d7637..0000000000 --- a/tensorflow/python/estimator/util_test.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Estimator related util.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools - -from tensorflow.python.estimator import util -from tensorflow.python.platform import test - - -class FnArgsTest(test.TestCase): - - def test_simple_function(self): - def fn(a, b): - return a + b - self.assertEqual(('a', 'b'), util.fn_args(fn)) - - def test_callable(self): - - class Foo(object): - - def __call__(self, a, b): - return a + b - - self.assertEqual(('a', 'b'), util.fn_args(Foo())) - - def test_bounded_method(self): - - class Foo(object): - - def bar(self, a, b): - return a + b - - self.assertEqual(('a', 'b'), util.fn_args(Foo().bar)) - - def test_partial_function(self): - expected_test_arg = 123 - - def fn(a, test_arg): - if test_arg != expected_test_arg: - return ValueError('partial fn does not work correctly') - return a - - wrapped_fn = functools.partial(fn, test_arg=123) - - self.assertEqual(('a',), util.fn_args(wrapped_fn)) - - def test_partial_function_with_positional_args(self): - expected_test_arg = 123 - - def fn(test_arg, a): - if test_arg != expected_test_arg: - return ValueError('partial fn does not work correctly') - return a - - wrapped_fn = functools.partial(fn, 123) - - self.assertEqual(('a',), util.fn_args(wrapped_fn)) - - self.assertEqual(3, wrapped_fn(3)) - self.assertEqual(3, wrapped_fn(a=3)) - - def test_double_partial(self): - expected_test_arg1 = 123 - expected_test_arg2 = 456 - - def fn(a, test_arg1, test_arg2): - if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2: - return ValueError('partial does not work correctly') - return a - - wrapped_fn = functools.partial(fn, test_arg2=456) - double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123) - - self.assertEqual(('a',), util.fn_args(double_wrapped_fn)) - - def test_double_partial_with_positional_args_in_outer_layer(self): - expected_test_arg1 = 123 - expected_test_arg2 = 456 - - def fn(test_arg1, a, test_arg2): - if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2: - return ValueError('partial fn does not work correctly') - return a - - wrapped_fn = functools.partial(fn, test_arg2=456) - double_wrapped_fn = functools.partial(wrapped_fn, 123) - - self.assertEqual(('a',), util.fn_args(double_wrapped_fn)) - - self.assertEqual(3, double_wrapped_fn(3)) - self.assertEqual(3, double_wrapped_fn(a=3)) - - def test_double_partial_with_positional_args_in_both_layers(self): - expected_test_arg1 = 123 - expected_test_arg2 = 456 - - def fn(test_arg1, test_arg2, a): - if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2: - return ValueError('partial fn does not work correctly') - return a - - wrapped_fn = functools.partial(fn, 123) # binds to test_arg1 - double_wrapped_fn = functools.partial(wrapped_fn, 456) # binds to test_arg2 - - self.assertEqual(('a',), util.fn_args(double_wrapped_fn)) - - self.assertEqual(3, double_wrapped_fn(3)) - self.assertEqual(3, double_wrapped_fn(a=3)) - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py index 16ee2952b2..72ab77fbbd 100644 --- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py +++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py @@ -25,7 +25,7 @@ import numpy as np from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.python.eager import context -from tensorflow.python.estimator import util as estimator_util +from tensorflow.python.estimator import util as function_utils from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -44,6 +44,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as tf_variables from tensorflow.python.training import checkpointable +from tensorflow.python.util import function_utils from tensorflow.python.util import nest from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect @@ -146,7 +147,7 @@ class Layer(checkpointable.CheckpointableBase): # return tensors. When using graph execution, _losses is a list of ops. self._losses = [] self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name - self._call_fn_args = estimator_util.fn_args(self.call) + self._call_fn_args = function_utils.fn_args(self.call) self._compute_previous_mask = ('mask' in self._call_fn_args or hasattr(self, 'compute_mask')) self._uses_inputs_arg = True @@ -644,7 +645,7 @@ class Layer(checkpointable.CheckpointableBase): self._compute_previous_mask): previous_mask = collect_previous_mask(inputs) if not hasattr(self, '_call_fn_args'): - self._call_fn_args = estimator_util.fn_args(self.call) + self._call_fn_args = function_utils.fn_args(self.call) if ('mask' in self._call_fn_args and 'mask' not in kwargs and not generic_utils.is_all_none(previous_mask)): # The previous layer generated a mask, and mask was not explicitly pass diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 64db49c900..2040e0081e 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -20,12 +20,12 @@ from __future__ import print_function import copy from tensorflow.python.eager import context -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.keras._impl.keras.engine import base_layer from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.util import function_utils from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -308,7 +308,7 @@ class Layer(base_layer.Layer): try: call_has_scope_arg = self._call_has_scope_arg except AttributeError: - self._call_fn_args = estimator_util.fn_args(self.call) + self._call_fn_args = function_utils.fn_args(self.call) self._call_has_scope_arg = 'scope' in self._call_fn_args call_has_scope_arg = self._call_has_scope_arg if call_has_scope_arg: diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index adb0f59948..f5970fdbb2 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -32,7 +32,6 @@ from six import iteritems from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.eager import context -from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -41,6 +40,7 @@ from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import function_utils from tensorflow.python.util import tf_contextlib from tensorflow.python.util.tf_export import tf_export @@ -422,7 +422,7 @@ class _VariableStore(object): "use_resource": use_resource, } # `fn_args` can handle functions, `functools.partial`, `lambda`. - if "constraint" in estimator_util.fn_args(custom_getter): + if "constraint" in function_utils.fn_args(custom_getter): custom_getter_kwargs["constraint"] = constraint return custom_getter(**custom_getter_kwargs) else: diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py index f584a009d9..fece3370f3 100644 --- a/tensorflow/python/training/monitored_session.py +++ b/tensorflow/python/training/monitored_session.py @@ -25,7 +25,6 @@ import sys import six from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.estimator import util from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -41,6 +40,7 @@ from tensorflow.python.training import queue_runner from tensorflow.python.training import saver as training_saver from tensorflow.python.training import session_manager as sm from tensorflow.python.training import session_run_hook +from tensorflow.python.util import function_utils from tensorflow.python.util.tf_export import tf_export @@ -620,7 +620,7 @@ class _MonitoredSession(object): `step_context`. It may also optionally have `self` for cases when it belongs to an object. """ - step_fn_arguments = util.fn_args(step_fn) + step_fn_arguments = function_utils.fn_args(step_fn) if step_fn_arguments != ('step_context',) and step_fn_arguments != ( 'self', 'step_context', diff --git a/tensorflow/python/util/function_utils.py b/tensorflow/python/util/function_utils.py new file mode 100644 index 0000000000..7bbbde3cd2 --- /dev/null +++ b/tensorflow/python/util/function_utils.py @@ -0,0 +1,57 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility to retrieve function args.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect + + +def _is_bounded_method(fn): + _, fn = tf_decorator.unwrap(fn) + return tf_inspect.ismethod(fn) and (fn.__self__ is not None) + + +def _is_callable_object(obj): + return hasattr(obj, '__call__') and tf_inspect.ismethod(obj.__call__) + + +def fn_args(fn): + """Get argument names for function-like object. + + Args: + fn: Function, or function-like object (e.g., result of `functools.partial`). + + Returns: + `tuple` of string argument names. + + Raises: + ValueError: if partial function has positionally bound arguments + """ + if isinstance(fn, functools.partial): + args = fn_args(fn.func) + args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])] + else: + if _is_callable_object(fn): + fn = fn.__call__ + args = tf_inspect.getfullargspec(fn).args + if _is_bounded_method(fn): + args.remove('self') + return tuple(args) diff --git a/tensorflow/python/util/function_utils_test.py b/tensorflow/python/util/function_utils_test.py new file mode 100644 index 0000000000..e78cf6a5b0 --- /dev/null +++ b/tensorflow/python/util/function_utils_test.py @@ -0,0 +1,128 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Estimator related util.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from tensorflow.python.platform import test +from tensorflow.python.util import function_utils + + +class FnArgsTest(test.TestCase): + + def test_simple_function(self): + def fn(a, b): + return a + b + self.assertEqual(('a', 'b'), function_utils.fn_args(fn)) + + def test_callable(self): + + class Foo(object): + + def __call__(self, a, b): + return a + b + + self.assertEqual(('a', 'b'), function_utils.fn_args(Foo())) + + def test_bounded_method(self): + + class Foo(object): + + def bar(self, a, b): + return a + b + + self.assertEqual(('a', 'b'), function_utils.fn_args(Foo().bar)) + + def test_partial_function(self): + expected_test_arg = 123 + + def fn(a, test_arg): + if test_arg != expected_test_arg: + return ValueError('partial fn does not work correctly') + return a + + wrapped_fn = functools.partial(fn, test_arg=123) + + self.assertEqual(('a',), function_utils.fn_args(wrapped_fn)) + + def test_partial_function_with_positional_args(self): + expected_test_arg = 123 + + def fn(test_arg, a): + if test_arg != expected_test_arg: + return ValueError('partial fn does not work correctly') + return a + + wrapped_fn = functools.partial(fn, 123) + + self.assertEqual(('a',), function_utils.fn_args(wrapped_fn)) + + self.assertEqual(3, wrapped_fn(3)) + self.assertEqual(3, wrapped_fn(a=3)) + + def test_double_partial(self): + expected_test_arg1 = 123 + expected_test_arg2 = 456 + + def fn(a, test_arg1, test_arg2): + if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2: + return ValueError('partial does not work correctly') + return a + + wrapped_fn = functools.partial(fn, test_arg2=456) + double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123) + + self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn)) + + def test_double_partial_with_positional_args_in_outer_layer(self): + expected_test_arg1 = 123 + expected_test_arg2 = 456 + + def fn(test_arg1, a, test_arg2): + if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2: + return ValueError('partial fn does not work correctly') + return a + + wrapped_fn = functools.partial(fn, test_arg2=456) + double_wrapped_fn = functools.partial(wrapped_fn, 123) + + self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn)) + + self.assertEqual(3, double_wrapped_fn(3)) + self.assertEqual(3, double_wrapped_fn(a=3)) + + def test_double_partial_with_positional_args_in_both_layers(self): + expected_test_arg1 = 123 + expected_test_arg2 = 456 + + def fn(test_arg1, test_arg2, a): + if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2: + return ValueError('partial fn does not work correctly') + return a + + wrapped_fn = functools.partial(fn, 123) # binds to test_arg1 + double_wrapped_fn = functools.partial(wrapped_fn, 456) # binds to test_arg2 + + self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn)) + + self.assertEqual(3, double_wrapped_fn(3)) + self.assertEqual(3, double_wrapped_fn(a=3)) + +if __name__ == '__main__': + test.main() -- cgit v1.2.3