diff options
author | Francois Chollet <fchollet@google.com> | 2017-08-14 17:50:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-14 17:54:20 -0700 |
commit | 4329a3b1e47c3e154925be3300382baeef59bdbb (patch) | |
tree | a057285928ae46c5284399f28d75bc6665e7cbef /tensorflow/contrib/keras | |
parent | e49d8dd4bb656a3165621391c55fd28f43969b52 (diff) |
Make it possible to use layers from `tf.layers` directly in a Keras model.
It is largely a tiny refactor. One addition to the public API: add method `count_params` to base layer class (which allows us to use the `model.summary()` method with models built with core layers).
PiperOrigin-RevId: 165255776
Diffstat (limited to 'tensorflow/contrib/keras')
-rw-r--r-- | tensorflow/contrib/keras/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/keras/python/keras/engine/topology.py | 37 | ||||
-rw-r--r-- | tensorflow/contrib/keras/python/keras/integration_test.py | 52 | ||||
-rw-r--r-- | tensorflow/contrib/keras/python/keras/models.py | 3 |
4 files changed, 63 insertions, 31 deletions
diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD index b923fc7e9a..5ae19bea33 100644 --- a/tensorflow/contrib/keras/BUILD +++ b/tensorflow/contrib/keras/BUILD @@ -157,6 +157,8 @@ py_test( deps = [ ":keras", "//tensorflow/python:client_testlib", + "//tensorflow/python:layers", + "//tensorflow/python:nn", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/contrib/keras/python/keras/engine/topology.py index 7124820254..67883bfb24 100644 --- a/tensorflow/contrib/keras/python/keras/engine/topology.py +++ b/tensorflow/contrib/keras/python/keras/engine/topology.py @@ -48,8 +48,11 @@ except ImportError: yaml = None # pylint: enable=g-import-not-at-top -InputSpec = tf_base_layers.InputSpec # pylint: disable=invalid-name -Node = tf_base_layers.Node # pylint: disable=invalid-name +# pylint: disable=invalid-name +InputSpec = tf_base_layers.InputSpec +Node = tf_base_layers.Node +TFBaseLayer = tf_base_layers.Layer +# pylint: enable=invalid-name class Layer(tf_base_layers.Layer): @@ -470,26 +473,6 @@ class Layer(tf_base_layers.Layer): """ return cls(**config) - def count_params(self): - """Count the total number of scalars composing the weights. - - Returns: - An integer count. - - Raises: - RuntimeError: if the layer isn't yet built - (in which case its weights aren't yet defined). - """ - if not self.built: - if self.__class__.__name__ == 'Sequential': - self.build() # pylint: disable=no-value-for-parameter - else: - raise RuntimeError('You tried to call `count_params` on ' + self.name + - ', but the layer isn\'t built. ' - 'You can build it manually via: `' + self.name + - '.build(batch_input_shape)`.') - return sum([K.count_params(p) for p in self.weights]) - class InputLayer(tf_base_layers.InputLayer, Layer): """Layer to be used as an entry point into a graph. @@ -543,13 +526,6 @@ class InputLayer(tf_base_layers.InputLayer, Layer): sparse=sparse, name=name) - if input_tensor is not None: - self.is_placeholder = False - self.batch_input_shape = tuple(input_tensor.get_shape().as_list()) - else: - self.is_placeholder = True - self.batch_input_shape = (batch_size,) + tuple(input_shape) - def get_config(self): config = { 'batch_input_shape': self.batch_input_shape, @@ -727,7 +703,8 @@ class Network(tf_base_layers.Network, Layer): @property def uses_learning_phase(self): - return any([x._uses_learning_phase for x in self.outputs]) + return any( + [getattr(x, '_uses_learning_phase', False) for x in self.outputs]) @property def stateful(self): diff --git a/tensorflow/contrib/keras/python/keras/integration_test.py b/tensorflow/contrib/keras/python/keras/integration_test.py index 32b0a95fe3..5c42ffcfbd 100644 --- a/tensorflow/contrib/keras/python/keras/integration_test.py +++ b/tensorflow/contrib/keras/python/keras/integration_test.py @@ -22,6 +22,9 @@ import numpy as np from tensorflow.contrib.keras.python import keras from tensorflow.contrib.keras.python.keras import testing_utils +from tensorflow.python.layers import base as tf_base_layers +from tensorflow.python.layers import core as tf_core_layers +from tensorflow.python.ops import nn from tensorflow.python.platform import test @@ -235,6 +238,55 @@ class KerasIntegrationTest(test.TestCase): model.compile(optimizer=keras.optimizers.SGD(clipnorm=0.1), loss='mse') model.fit(np.array([[0]]), np.array([[[0.5]]]), epochs=1) + def test_using_tf_layers_in_keras_sequential_model(self): + with self.test_session(): + np.random.seed(1337) + (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( + train_samples=200, + test_samples=100, + input_shape=(10,), + num_classes=2) + + model = keras.models.Sequential() + model.add(tf_core_layers.Dense(32, activation=nn.relu, input_shape=(10,))) + model.add(tf_core_layers.Dense(2, activation=nn.softmax)) + model.summary() + + y_train = keras.utils.to_categorical(y_train) + y_test = keras.utils.to_categorical(y_test) + model.compile(loss='categorical_crossentropy', + optimizer='adam', + metrics=['accuracy']) + history = model.fit(x_train, y_train, epochs=10, batch_size=16, + validation_data=(x_test, y_test), + verbose=0) + self.assertGreater(history.history['val_acc'][-1], 0.85) + + def test_using_tf_layers_in_keras_functional_model(self): + with self.test_session(): + np.random.seed(1337) + (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( + train_samples=200, + test_samples=100, + input_shape=(10,), + num_classes=2) + y_train = keras.utils.to_categorical(y_train) + y_test = keras.utils.to_categorical(y_test) + + inputs = tf_base_layers.Input(shape=(10,)) + x = tf_core_layers.Dense(32, activation=nn.relu)(inputs) + outputs = tf_core_layers.Dense(2, activation=nn.softmax)(x) + model = keras.models.Model(inputs, outputs) + model.summary() + + model.compile(loss='categorical_crossentropy', + optimizer='adam', + metrics=['accuracy']) + history = model.fit(x_train, y_train, epochs=10, batch_size=16, + validation_data=(x_test, y_test), + verbose=0) + self.assertGreater(history.history['val_acc'][-1], 0.85) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/keras/python/keras/models.py b/tensorflow/contrib/keras/python/keras/models.py index 813a462f55..c01d2c45e0 100644 --- a/tensorflow/contrib/keras/python/keras/models.py +++ b/tensorflow/contrib/keras/python/keras/models.py @@ -32,6 +32,7 @@ from tensorflow.contrib.keras.python.keras import optimizers from tensorflow.contrib.keras.python.keras.engine import topology from tensorflow.contrib.keras.python.keras.engine.topology import Input from tensorflow.contrib.keras.python.keras.engine.topology import Layer +from tensorflow.contrib.keras.python.keras.engine.topology import TFBaseLayer from tensorflow.contrib.keras.python.keras.engine.training import Model from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.framework import ops @@ -455,7 +456,7 @@ class Sequential(Model): multiple output tensors, or is already connected somewhere else (forbidden in `Sequential` models). """ - if not isinstance(layer, Layer): + if not isinstance(layer, (Layer, TFBaseLayer)): raise TypeError('The added layer must be ' 'an instance of class Layer. ' 'Found: ' + str(layer)) |