aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/keras
diff options
context:
space:
mode:
authorGravatar Charles Nicholson <nicholsonc@google.com>2017-04-21 10:59:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-21 12:09:42 -0700
commit8e5041918f2e709ded94e63fb1779d6bb363becb (patch)
tree5ec03d51d2e15b1247b68298b2ccd9138075513e /tensorflow/contrib/keras
parentc3bf39b7a6c3cc41f209ac863c764498b503d4f5 (diff)
Introduce TFDecorator, a base class for Python TensorFlow decorators. Provides basic introspection and "unwrap" services, allowing tooling code to fully 'understand' the wrapped object.
Change: 153854044
Diffstat (limited to 'tensorflow/contrib/keras')
-rw-r--r--tensorflow/contrib/keras/python/keras/backend_test.py5
-rw-r--r--tensorflow/contrib/keras/python/keras/engine/topology.py8
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/core.py4
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/wrappers.py4
-rw-r--r--tensorflow/contrib/keras/python/keras/testing_utils.py5
-rw-r--r--tensorflow/contrib/keras/python/keras/utils/generic_utils.py6
-rw-r--r--tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py6
7 files changed, 19 insertions, 19 deletions
diff --git a/tensorflow/contrib/keras/python/keras/backend_test.py b/tensorflow/contrib/keras/python/keras/backend_test.py
index fd9db1f327..2da5aee58e 100644
--- a/tensorflow/contrib/keras/python/keras/backend_test.py
+++ b/tensorflow/contrib/keras/python/keras/backend_test.py
@@ -18,12 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import inspect
-
import numpy as np
from tensorflow.contrib.keras.python import keras
from tensorflow.python.platform import test
+from tensorflow.python.util import tf_inspect
def compare_single_input_op_to_numpy(keras_op,
@@ -207,7 +206,7 @@ class BackendLinearAlgebraTest(test.TestCase):
compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5),
keras_kwargs={'axis': -1},
np_kwargs={'axis': -1})
- if 'keepdims' in inspect.getargspec(keras_op).args:
+ if 'keepdims' in tf_inspect.getargspec(keras_op).args:
compare_single_input_op_to_numpy(keras_op, np_op,
input_shape=(4, 7, 5),
keras_kwargs={'axis': 1,
diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/contrib/keras/python/keras/engine/topology.py
index 0d1812aaa2..7848e5982d 100644
--- a/tensorflow/contrib/keras/python/keras/engine/topology.py
+++ b/tensorflow/contrib/keras/python/keras/engine/topology.py
@@ -20,7 +20,6 @@ from __future__ import division
from __future__ import print_function
import copy
-import inspect
import json
import os
import re
@@ -35,6 +34,7 @@ from tensorflow.contrib.keras.python.keras.utils import conv_utils
from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.contrib.keras.python.keras.utils.layer_utils import print_summary as print_layer_summary
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.util import tf_inspect
# pylint: disable=g-import-not-at-top
@@ -584,7 +584,7 @@ class Layer(object):
user_kwargs = copy.copy(kwargs)
if not _is_all_none(previous_mask):
# The previous layer generated a mask.
- if 'mask' in inspect.getargspec(self.call).args:
+ if 'mask' in tf_inspect.getargspec(self.call).args:
if 'mask' not in kwargs:
# If mask is explicitly passed to __call__,
# we should override the default mask.
@@ -2166,7 +2166,7 @@ class Container(Layer):
kwargs = {}
if len(computed_data) == 1:
computed_tensor, computed_mask = computed_data[0]
- if 'mask' in inspect.getargspec(layer.call).args:
+ if 'mask' in tf_inspect.getargspec(layer.call).args:
if 'mask' not in kwargs:
kwargs['mask'] = computed_mask
output_tensors = _to_list(layer.call(computed_tensor, **kwargs))
@@ -2177,7 +2177,7 @@ class Container(Layer):
else:
computed_tensors = [x[0] for x in computed_data]
computed_masks = [x[1] for x in computed_data]
- if 'mask' in inspect.getargspec(layer.call).args:
+ if 'mask' in tf_inspect.getargspec(layer.call).args:
if 'mask' not in kwargs:
kwargs['mask'] = computed_masks
output_tensors = _to_list(layer.call(computed_tensors, **kwargs))
diff --git a/tensorflow/contrib/keras/python/keras/layers/core.py b/tensorflow/contrib/keras/python/keras/layers/core.py
index 8dd55aaa2e..32ada176a4 100644
--- a/tensorflow/contrib/keras/python/keras/layers/core.py
+++ b/tensorflow/contrib/keras/python/keras/layers/core.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import copy
-import inspect
import types as python_types
import numpy as np
@@ -35,6 +34,7 @@ from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserializ
from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_dump
from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_load
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.util import tf_inspect
class Masking(Layer):
@@ -595,7 +595,7 @@ class Lambda(Layer):
def call(self, inputs, mask=None):
arguments = self.arguments
- arg_spec = inspect.getargspec(self.function)
+ arg_spec = tf_inspect.getargspec(self.function)
if 'mask' in arg_spec.args:
arguments['mask'] = mask
return self.function(inputs, **arguments)
diff --git a/tensorflow/contrib/keras/python/keras/layers/wrappers.py b/tensorflow/contrib/keras/python/keras/layers/wrappers.py
index a322696514..ce6458fd0c 100644
--- a/tensorflow/contrib/keras/python/keras/layers/wrappers.py
+++ b/tensorflow/contrib/keras/python/keras/layers/wrappers.py
@@ -20,12 +20,12 @@ from __future__ import division
from __future__ import print_function
import copy
-import inspect
from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras.engine import InputSpec
from tensorflow.contrib.keras.python.keras.engine import Layer
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.util import tf_inspect
class Wrapper(Layer):
@@ -284,7 +284,7 @@ class Bidirectional(Wrapper):
def call(self, inputs, training=None, mask=None):
kwargs = {}
- func_args = inspect.getargspec(self.layer.call).args
+ func_args = tf_inspect.getargspec(self.layer.call).args
if 'training' in func_args:
kwargs['training'] = training
if 'mask' in func_args:
diff --git a/tensorflow/contrib/keras/python/keras/testing_utils.py b/tensorflow/contrib/keras/python/keras/testing_utils.py
index baba5447d9..bf6f661adf 100644
--- a/tensorflow/contrib/keras/python/keras/testing_utils.py
+++ b/tensorflow/contrib/keras/python/keras/testing_utils.py
@@ -18,11 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import inspect
-
import numpy as np
from tensorflow.contrib.keras.python import keras
+from tensorflow.python.util import tf_inspect
def get_test_data(train_samples,
@@ -98,7 +97,7 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
layer.set_weights(weights)
# test and instantiation from weights
- if 'weights' in inspect.getargspec(layer_cls.__init__):
+ if 'weights' in tf_inspect.getargspec(layer_cls.__init__):
kwargs['weights'] = weights
layer = layer_cls(**kwargs)
diff --git a/tensorflow/contrib/keras/python/keras/utils/generic_utils.py b/tensorflow/contrib/keras/python/keras/utils/generic_utils.py
index 4c95c314b1..27cc23f232 100644
--- a/tensorflow/contrib/keras/python/keras/utils/generic_utils.py
+++ b/tensorflow/contrib/keras/python/keras/utils/generic_utils.py
@@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import inspect
import marshal
import sys
import time
@@ -26,6 +25,8 @@ import types as python_types
import numpy as np
import six
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
_GLOBAL_CUSTOM_OBJECTS = {}
@@ -116,6 +117,7 @@ def get_custom_objects():
def serialize_keras_object(instance):
+ _, instance = tf_decorator.unwrap(instance)
if instance is None:
return None
if hasattr(instance, 'get_config'):
@@ -149,7 +151,7 @@ def deserialize_keras_object(identifier,
if cls is None:
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
if hasattr(cls, 'from_config'):
- arg_spec = inspect.getargspec(cls.from_config)
+ arg_spec = tf_inspect.getargspec(cls.from_config)
if 'custom_objects' in arg_spec.args:
custom_objects = custom_objects or {}
return cls.from_config(
diff --git a/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py b/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py
index 323c31aee8..9f8cea375b 100644
--- a/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py
+++ b/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py
@@ -19,13 +19,13 @@ from __future__ import division
from __future__ import print_function
import copy
-import inspect
import types
import numpy as np
from tensorflow.contrib.keras.python.keras.models import Sequential
from tensorflow.contrib.keras.python.keras.utils.np_utils import to_categorical
+from tensorflow.python.util import tf_inspect
class BaseWrapper(object):
@@ -97,7 +97,7 @@ class BaseWrapper(object):
legal_params = []
for fn in legal_params_fns:
- legal_params += inspect.getargspec(fn)[0]
+ legal_params += tf_inspect.getargspec(fn)[0]
legal_params = set(legal_params)
for params_name in params:
@@ -182,7 +182,7 @@ class BaseWrapper(object):
"""
override = override or {}
res = {}
- fn_args = inspect.getargspec(fn)[0]
+ fn_args = tf_inspect.getargspec(fn)[0]
for name, value in self.sk_params.items():
if name in fn_args:
res.update({name: value})