aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/keras
diff options
context:
space:
mode:
authorGravatar Francois Chollet <fchollet@google.com>2017-08-14 17:50:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-14 17:54:20 -0700
commit4329a3b1e47c3e154925be3300382baeef59bdbb (patch)
treea057285928ae46c5284399f28d75bc6665e7cbef /tensorflow/contrib/keras
parente49d8dd4bb656a3165621391c55fd28f43969b52 (diff)
Make it possible to use layers from `tf.layers` directly in a Keras model.
It is largely a tiny refactor. One addition to the public API: add method `count_params` to base layer class (which allows us to use the `model.summary()` method with models built with core layers). PiperOrigin-RevId: 165255776
Diffstat (limited to 'tensorflow/contrib/keras')
-rw-r--r--tensorflow/contrib/keras/BUILD2
-rw-r--r--tensorflow/contrib/keras/python/keras/engine/topology.py37
-rw-r--r--tensorflow/contrib/keras/python/keras/integration_test.py52
-rw-r--r--tensorflow/contrib/keras/python/keras/models.py3
4 files changed, 63 insertions, 31 deletions
diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD
index b923fc7e9a..5ae19bea33 100644
--- a/tensorflow/contrib/keras/BUILD
+++ b/tensorflow/contrib/keras/BUILD
@@ -157,6 +157,8 @@ py_test(
deps = [
":keras",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:nn",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/contrib/keras/python/keras/engine/topology.py
index 7124820254..67883bfb24 100644
--- a/tensorflow/contrib/keras/python/keras/engine/topology.py
+++ b/tensorflow/contrib/keras/python/keras/engine/topology.py
@@ -48,8 +48,11 @@ except ImportError:
yaml = None
# pylint: enable=g-import-not-at-top
-InputSpec = tf_base_layers.InputSpec # pylint: disable=invalid-name
-Node = tf_base_layers.Node # pylint: disable=invalid-name
+# pylint: disable=invalid-name
+InputSpec = tf_base_layers.InputSpec
+Node = tf_base_layers.Node
+TFBaseLayer = tf_base_layers.Layer
+# pylint: enable=invalid-name
class Layer(tf_base_layers.Layer):
@@ -470,26 +473,6 @@ class Layer(tf_base_layers.Layer):
"""
return cls(**config)
- def count_params(self):
- """Count the total number of scalars composing the weights.
-
- Returns:
- An integer count.
-
- Raises:
- RuntimeError: if the layer isn't yet built
- (in which case its weights aren't yet defined).
- """
- if not self.built:
- if self.__class__.__name__ == 'Sequential':
- self.build() # pylint: disable=no-value-for-parameter
- else:
- raise RuntimeError('You tried to call `count_params` on ' + self.name +
- ', but the layer isn\'t built. '
- 'You can build it manually via: `' + self.name +
- '.build(batch_input_shape)`.')
- return sum([K.count_params(p) for p in self.weights])
-
class InputLayer(tf_base_layers.InputLayer, Layer):
"""Layer to be used as an entry point into a graph.
@@ -543,13 +526,6 @@ class InputLayer(tf_base_layers.InputLayer, Layer):
sparse=sparse,
name=name)
- if input_tensor is not None:
- self.is_placeholder = False
- self.batch_input_shape = tuple(input_tensor.get_shape().as_list())
- else:
- self.is_placeholder = True
- self.batch_input_shape = (batch_size,) + tuple(input_shape)
-
def get_config(self):
config = {
'batch_input_shape': self.batch_input_shape,
@@ -727,7 +703,8 @@ class Network(tf_base_layers.Network, Layer):
@property
def uses_learning_phase(self):
- return any([x._uses_learning_phase for x in self.outputs])
+ return any(
+ [getattr(x, '_uses_learning_phase', False) for x in self.outputs])
@property
def stateful(self):
diff --git a/tensorflow/contrib/keras/python/keras/integration_test.py b/tensorflow/contrib/keras/python/keras/integration_test.py
index 32b0a95fe3..5c42ffcfbd 100644
--- a/tensorflow/contrib/keras/python/keras/integration_test.py
+++ b/tensorflow/contrib/keras/python/keras/integration_test.py
@@ -22,6 +22,9 @@ import numpy as np
from tensorflow.contrib.keras.python import keras
from tensorflow.contrib.keras.python.keras import testing_utils
+from tensorflow.python.layers import base as tf_base_layers
+from tensorflow.python.layers import core as tf_core_layers
+from tensorflow.python.ops import nn
from tensorflow.python.platform import test
@@ -235,6 +238,55 @@ class KerasIntegrationTest(test.TestCase):
model.compile(optimizer=keras.optimizers.SGD(clipnorm=0.1), loss='mse')
model.fit(np.array([[0]]), np.array([[[0.5]]]), epochs=1)
+ def test_using_tf_layers_in_keras_sequential_model(self):
+ with self.test_session():
+ np.random.seed(1337)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=200,
+ test_samples=100,
+ input_shape=(10,),
+ num_classes=2)
+
+ model = keras.models.Sequential()
+ model.add(tf_core_layers.Dense(32, activation=nn.relu, input_shape=(10,)))
+ model.add(tf_core_layers.Dense(2, activation=nn.softmax))
+ model.summary()
+
+ y_train = keras.utils.to_categorical(y_train)
+ y_test = keras.utils.to_categorical(y_test)
+ model.compile(loss='categorical_crossentropy',
+ optimizer='adam',
+ metrics=['accuracy'])
+ history = model.fit(x_train, y_train, epochs=10, batch_size=16,
+ validation_data=(x_test, y_test),
+ verbose=0)
+ self.assertGreater(history.history['val_acc'][-1], 0.85)
+
+ def test_using_tf_layers_in_keras_functional_model(self):
+ with self.test_session():
+ np.random.seed(1337)
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=200,
+ test_samples=100,
+ input_shape=(10,),
+ num_classes=2)
+ y_train = keras.utils.to_categorical(y_train)
+ y_test = keras.utils.to_categorical(y_test)
+
+ inputs = tf_base_layers.Input(shape=(10,))
+ x = tf_core_layers.Dense(32, activation=nn.relu)(inputs)
+ outputs = tf_core_layers.Dense(2, activation=nn.softmax)(x)
+ model = keras.models.Model(inputs, outputs)
+ model.summary()
+
+ model.compile(loss='categorical_crossentropy',
+ optimizer='adam',
+ metrics=['accuracy'])
+ history = model.fit(x_train, y_train, epochs=10, batch_size=16,
+ validation_data=(x_test, y_test),
+ verbose=0)
+ self.assertGreater(history.history['val_acc'][-1], 0.85)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/keras/python/keras/models.py b/tensorflow/contrib/keras/python/keras/models.py
index 813a462f55..c01d2c45e0 100644
--- a/tensorflow/contrib/keras/python/keras/models.py
+++ b/tensorflow/contrib/keras/python/keras/models.py
@@ -32,6 +32,7 @@ from tensorflow.contrib.keras.python.keras import optimizers
from tensorflow.contrib.keras.python.keras.engine import topology
from tensorflow.contrib.keras.python.keras.engine.topology import Input
from tensorflow.contrib.keras.python.keras.engine.topology import Layer
+from tensorflow.contrib.keras.python.keras.engine.topology import TFBaseLayer
from tensorflow.contrib.keras.python.keras.engine.training import Model
from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.framework import ops
@@ -455,7 +456,7 @@ class Sequential(Model):
multiple output tensors, or is already connected
somewhere else (forbidden in `Sequential` models).
"""
- if not isinstance(layer, Layer):
+ if not isinstance(layer, (Layer, TFBaseLayer)):
raise TypeError('The added layer must be '
'an instance of class Layer. '
'Found: ' + str(layer))