diff options
author | Charles Nicholson <nicholsonc@google.com> | 2017-04-21 10:59:14 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-04-21 12:09:42 -0700 |
commit | 8e5041918f2e709ded94e63fb1779d6bb363becb (patch) | |
tree | 5ec03d51d2e15b1247b68298b2ccd9138075513e /tensorflow/contrib/keras | |
parent | c3bf39b7a6c3cc41f209ac863c764498b503d4f5 (diff) |
Introduce TFDecorator, a base class for Python TensorFlow decorators. Provides basic introspection and "unwrap" services, allowing tooling code to fully 'understand' the wrapped object.
Change: 153854044
Diffstat (limited to 'tensorflow/contrib/keras')
7 files changed, 19 insertions, 19 deletions
diff --git a/tensorflow/contrib/keras/python/keras/backend_test.py b/tensorflow/contrib/keras/python/keras/backend_test.py index fd9db1f327..2da5aee58e 100644 --- a/tensorflow/contrib/keras/python/keras/backend_test.py +++ b/tensorflow/contrib/keras/python/keras/backend_test.py @@ -18,12 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect - import numpy as np from tensorflow.contrib.keras.python import keras from tensorflow.python.platform import test +from tensorflow.python.util import tf_inspect def compare_single_input_op_to_numpy(keras_op, @@ -207,7 +206,7 @@ class BackendLinearAlgebraTest(test.TestCase): compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5), keras_kwargs={'axis': -1}, np_kwargs={'axis': -1}) - if 'keepdims' in inspect.getargspec(keras_op).args: + if 'keepdims' in tf_inspect.getargspec(keras_op).args: compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5), keras_kwargs={'axis': 1, diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/contrib/keras/python/keras/engine/topology.py index 0d1812aaa2..7848e5982d 100644 --- a/tensorflow/contrib/keras/python/keras/engine/topology.py +++ b/tensorflow/contrib/keras/python/keras/engine/topology.py @@ -20,7 +20,6 @@ from __future__ import division from __future__ import print_function import copy -import inspect import json import os import re @@ -35,6 +34,7 @@ from tensorflow.contrib.keras.python.keras.utils import conv_utils from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.contrib.keras.python.keras.utils.layer_utils import print_summary as print_layer_summary from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import tf_inspect # pylint: disable=g-import-not-at-top @@ -584,7 +584,7 @@ class Layer(object): user_kwargs = copy.copy(kwargs) if not _is_all_none(previous_mask): # The previous layer generated a mask. - if 'mask' in inspect.getargspec(self.call).args: + if 'mask' in tf_inspect.getargspec(self.call).args: if 'mask' not in kwargs: # If mask is explicitly passed to __call__, # we should override the default mask. @@ -2166,7 +2166,7 @@ class Container(Layer): kwargs = {} if len(computed_data) == 1: computed_tensor, computed_mask = computed_data[0] - if 'mask' in inspect.getargspec(layer.call).args: + if 'mask' in tf_inspect.getargspec(layer.call).args: if 'mask' not in kwargs: kwargs['mask'] = computed_mask output_tensors = _to_list(layer.call(computed_tensor, **kwargs)) @@ -2177,7 +2177,7 @@ class Container(Layer): else: computed_tensors = [x[0] for x in computed_data] computed_masks = [x[1] for x in computed_data] - if 'mask' in inspect.getargspec(layer.call).args: + if 'mask' in tf_inspect.getargspec(layer.call).args: if 'mask' not in kwargs: kwargs['mask'] = computed_masks output_tensors = _to_list(layer.call(computed_tensors, **kwargs)) diff --git a/tensorflow/contrib/keras/python/keras/layers/core.py b/tensorflow/contrib/keras/python/keras/layers/core.py index 8dd55aaa2e..32ada176a4 100644 --- a/tensorflow/contrib/keras/python/keras/layers/core.py +++ b/tensorflow/contrib/keras/python/keras/layers/core.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import copy -import inspect import types as python_types import numpy as np @@ -35,6 +34,7 @@ from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserializ from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_dump from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_load from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import tf_inspect class Masking(Layer): @@ -595,7 +595,7 @@ class Lambda(Layer): def call(self, inputs, mask=None): arguments = self.arguments - arg_spec = inspect.getargspec(self.function) + arg_spec = tf_inspect.getargspec(self.function) if 'mask' in arg_spec.args: arguments['mask'] = mask return self.function(inputs, **arguments) diff --git a/tensorflow/contrib/keras/python/keras/layers/wrappers.py b/tensorflow/contrib/keras/python/keras/layers/wrappers.py index a322696514..ce6458fd0c 100644 --- a/tensorflow/contrib/keras/python/keras/layers/wrappers.py +++ b/tensorflow/contrib/keras/python/keras/layers/wrappers.py @@ -20,12 +20,12 @@ from __future__ import division from __future__ import print_function import copy -import inspect from tensorflow.contrib.keras.python.keras import backend as K from tensorflow.contrib.keras.python.keras.engine import InputSpec from tensorflow.contrib.keras.python.keras.engine import Layer from tensorflow.python.framework import tensor_shape +from tensorflow.python.util import tf_inspect class Wrapper(Layer): @@ -284,7 +284,7 @@ class Bidirectional(Wrapper): def call(self, inputs, training=None, mask=None): kwargs = {} - func_args = inspect.getargspec(self.layer.call).args + func_args = tf_inspect.getargspec(self.layer.call).args if 'training' in func_args: kwargs['training'] = training if 'mask' in func_args: diff --git a/tensorflow/contrib/keras/python/keras/testing_utils.py b/tensorflow/contrib/keras/python/keras/testing_utils.py index baba5447d9..bf6f661adf 100644 --- a/tensorflow/contrib/keras/python/keras/testing_utils.py +++ b/tensorflow/contrib/keras/python/keras/testing_utils.py @@ -18,11 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect - import numpy as np from tensorflow.contrib.keras.python import keras +from tensorflow.python.util import tf_inspect def get_test_data(train_samples, @@ -98,7 +97,7 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None, layer.set_weights(weights) # test and instantiation from weights - if 'weights' in inspect.getargspec(layer_cls.__init__): + if 'weights' in tf_inspect.getargspec(layer_cls.__init__): kwargs['weights'] = weights layer = layer_cls(**kwargs) diff --git a/tensorflow/contrib/keras/python/keras/utils/generic_utils.py b/tensorflow/contrib/keras/python/keras/utils/generic_utils.py index 4c95c314b1..27cc23f232 100644 --- a/tensorflow/contrib/keras/python/keras/utils/generic_utils.py +++ b/tensorflow/contrib/keras/python/keras/utils/generic_utils.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import inspect import marshal import sys import time @@ -26,6 +25,8 @@ import types as python_types import numpy as np import six +from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect _GLOBAL_CUSTOM_OBJECTS = {} @@ -116,6 +117,7 @@ def get_custom_objects(): def serialize_keras_object(instance): + _, instance = tf_decorator.unwrap(instance) if instance is None: return None if hasattr(instance, 'get_config'): @@ -149,7 +151,7 @@ def deserialize_keras_object(identifier, if cls is None: raise ValueError('Unknown ' + printable_module_name + ': ' + class_name) if hasattr(cls, 'from_config'): - arg_spec = inspect.getargspec(cls.from_config) + arg_spec = tf_inspect.getargspec(cls.from_config) if 'custom_objects' in arg_spec.args: custom_objects = custom_objects or {} return cls.from_config( diff --git a/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py b/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py index 323c31aee8..9f8cea375b 100644 --- a/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py +++ b/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py @@ -19,13 +19,13 @@ from __future__ import division from __future__ import print_function import copy -import inspect import types import numpy as np from tensorflow.contrib.keras.python.keras.models import Sequential from tensorflow.contrib.keras.python.keras.utils.np_utils import to_categorical +from tensorflow.python.util import tf_inspect class BaseWrapper(object): @@ -97,7 +97,7 @@ class BaseWrapper(object): legal_params = [] for fn in legal_params_fns: - legal_params += inspect.getargspec(fn)[0] + legal_params += tf_inspect.getargspec(fn)[0] legal_params = set(legal_params) for params_name in params: @@ -182,7 +182,7 @@ class BaseWrapper(object): """ override = override or {} res = {} - fn_args = inspect.getargspec(fn)[0] + fn_args = tf_inspect.getargspec(fn)[0] for name, value in self.sk_params.items(): if name in fn_args: res.update({name: value}) |