aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/keras
diff options
context:
space:
mode:
authorGravatar Francois Chollet <fchollet@google.com>2017-08-23 17:09:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-23 17:13:11 -0700
commitfa83a270c943317da6b07a3d093c224be0827bd9 (patch)
tree574e6b0feae781e071d85030f03d0f52823056bb /tensorflow/contrib/keras
parent1f41602a82cb68fc7bc7e51cf9590a87ee5baf4d (diff)
Increase tf.contrib.keras test coverage to 90%.
PiperOrigin-RevId: 166277862
Diffstat (limited to 'tensorflow/contrib/keras')
-rw-r--r--tensorflow/contrib/keras/BUILD25
-rw-r--r--tensorflow/contrib/keras/python/keras/activations_test.py11
-rw-r--r--tensorflow/contrib/keras/python/keras/applications/imagenet_utils_test.py16
-rw-r--r--tensorflow/contrib/keras/python/keras/applications/inception_v3_test.py16
-rw-r--r--tensorflow/contrib/keras/python/keras/applications/mobilenet_test.py21
-rw-r--r--tensorflow/contrib/keras/python/keras/applications/resnet50_test.py9
-rw-r--r--tensorflow/contrib/keras/python/keras/applications/vgg16_test.py8
-rw-r--r--tensorflow/contrib/keras/python/keras/applications/vgg19_test.py8
-rw-r--r--tensorflow/contrib/keras/python/keras/applications/xception_test.py15
-rw-r--r--tensorflow/contrib/keras/python/keras/backend.py36
-rw-r--r--tensorflow/contrib/keras/python/keras/backend_test.py555
-rw-r--r--tensorflow/contrib/keras/python/keras/callbacks.py6
-rw-r--r--tensorflow/contrib/keras/python/keras/callbacks_test.py91
-rw-r--r--tensorflow/contrib/keras/python/keras/engine/topology.py45
-rw-r--r--tensorflow/contrib/keras/python/keras/engine/topology_test.py138
-rw-r--r--tensorflow/contrib/keras/python/keras/engine/training.py3
-rw-r--r--tensorflow/contrib/keras/python/keras/engine/training_test.py342
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/convolutional_test.py104
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/core.py14
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/core_test.py19
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/merge.py2
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/merge_test.py21
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/normalization_test.py14
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/wrappers.py3
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/wrappers_test.py65
-rw-r--r--tensorflow/contrib/keras/python/keras/models.py20
-rw-r--r--tensorflow/contrib/keras/python/keras/models_test.py106
-rw-r--r--tensorflow/contrib/keras/python/keras/preprocessing/image_test.py27
-rw-r--r--tensorflow/contrib/keras/python/keras/utils/conv_utils.py107
-rw-r--r--tensorflow/contrib/keras/python/keras/utils/data_utils_test.py61
-rw-r--r--tensorflow/contrib/keras/python/keras/utils/io_utils_test.py100
31 files changed, 1635 insertions, 373 deletions
diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD
index 5ae19bea33..7a562f727e 100644
--- a/tensorflow/contrib/keras/BUILD
+++ b/tensorflow/contrib/keras/BUILD
@@ -257,6 +257,7 @@ py_test(
deps = [
":keras",
"//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
],
)
@@ -268,6 +269,7 @@ py_test(
deps = [
":keras",
"//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
],
)
@@ -312,6 +314,7 @@ py_test(
deps = [
":keras",
"//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
],
)
@@ -426,14 +429,15 @@ py_test(
],
)
-cuda_py_test(
+py_test(
name = "normalization_test",
size = "small",
srcs = ["python/keras/layers/normalization_test.py"],
- additional_deps = [
+ srcs_version = "PY2AND3",
+ deps = [
":keras",
- "//third_party/py/numpy",
"//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
],
)
@@ -543,6 +547,18 @@ py_test(
)
py_test(
+ name = "io_utils_test",
+ size = "small",
+ srcs = ["python/keras/utils/io_utils_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":keras",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "imagenet_utils_test",
size = "small",
srcs = ["python/keras/applications/imagenet_utils_test.py"],
@@ -605,7 +621,7 @@ py_test(
py_test(
name = "training_test",
- size = "small",
+ size = "medium",
srcs = ["python/keras/engine/training_test.py"],
srcs_version = "PY2AND3",
tags = ["notsan"],
@@ -638,6 +654,7 @@ py_test(
deps = [
":keras",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:training",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/keras/python/keras/activations_test.py b/tensorflow/contrib/keras/python/keras/activations_test.py
index 3d21610e49..8efa464b03 100644
--- a/tensorflow/contrib/keras/python/keras/activations_test.py
+++ b/tensorflow/contrib/keras/python/keras/activations_test.py
@@ -54,6 +54,10 @@ class KerasActivationsTest(test.TestCase):
expected = _ref_softmax(test_values[0])
self.assertAllClose(result[0], expected, rtol=1e-05)
+ with self.assertRaises(ValueError):
+ x = keras.backend.placeholder(ndim=1)
+ keras.activations.softmax(x)
+
def test_temporal_softmax(self):
with self.test_session():
x = keras.backend.placeholder(shape=(2, 2, 3))
@@ -169,5 +173,12 @@ class KerasActivationsTest(test.TestCase):
x = np.random.random((10, 5))
self.assertAllClose(x, keras.activations.linear(x))
+ def test_invalid_usage(self):
+ with self.assertRaises(ValueError):
+ keras.activations.get('unknown')
+
+ # The following should be possible but should raise a warning:
+ keras.activations.get(keras.layers.LeakyReLU())
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/keras/python/keras/applications/imagenet_utils_test.py b/tensorflow/contrib/keras/python/keras/applications/imagenet_utils_test.py
index f3bcf93a95..378c06d30d 100644
--- a/tensorflow/contrib/keras/python/keras/applications/imagenet_utils_test.py
+++ b/tensorflow/contrib/keras/python/keras/applications/imagenet_utils_test.py
@@ -37,18 +37,6 @@ class ImageNetUtilsTest(test.TestCase):
np.transpose(x, (0, 3, 1, 2)), 'channels_first')
self.assertAllClose(out1, out2.transpose(0, 2, 3, 1))
- def test_decode_predictions(self):
- x = np.zeros((2, 1000))
- x[0, 372] = 1.0
- x[1, 549] = 1.0
- outs = keras.applications.imagenet_utils.decode_predictions(x, top=1)
- scores = [out[0][2] for out in outs]
- self.assertEqual(scores[0], scores[1])
-
- # the numbers of columns and ImageNet classes are not identical.
- with self.assertRaises(ValueError):
- keras.applications.imagenet_utils.decode_predictions(np.ones((2, 100)))
-
def test_obtain_input_shape(self):
# input_shape and default_size are not identical.
with self.assertRaises(ValueError):
@@ -137,3 +125,7 @@ class ImageNetUtilsTest(test.TestCase):
min_size=139,
data_format='channels_first',
include_top=False) == (3, None, None)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/keras/python/keras/applications/inception_v3_test.py b/tensorflow/contrib/keras/python/keras/applications/inception_v3_test.py
index 586f0da270..890df612ff 100644
--- a/tensorflow/contrib/keras/python/keras/applications/inception_v3_test.py
+++ b/tensorflow/contrib/keras/python/keras/applications/inception_v3_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.keras.python import keras
from tensorflow.python.platform import test
@@ -38,5 +40,19 @@ class InceptionV3Test(test.TestCase):
pooling='avg')
self.assertEqual(model.output_shape, (None, 2048))
+ def test_weight_loading(self):
+ with self.assertRaises(ValueError):
+ keras.applications.InceptionV3(weights='unknown',
+ include_top=False)
+ with self.assertRaises(ValueError):
+ keras.applications.InceptionV3(weights='imagenet',
+ classes=2000)
+
+ def test_preprocess_input(self):
+ x = np.random.uniform(0, 255, (2, 300, 200, 3))
+ out1 = keras.applications.inception_v3.preprocess_input(x)
+ self.assertAllClose(np.mean(out1), 0., atol=0.1)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/keras/python/keras/applications/mobilenet_test.py b/tensorflow/contrib/keras/python/keras/applications/mobilenet_test.py
index 6aa786f9b1..d67964c02b 100644
--- a/tensorflow/contrib/keras/python/keras/applications/mobilenet_test.py
+++ b/tensorflow/contrib/keras/python/keras/applications/mobilenet_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.keras.python import keras
from tensorflow.python.platform import test
@@ -38,5 +40,24 @@ class MobileNetTest(test.TestCase):
pooling='avg')
self.assertEqual(model.output_shape, (None, 1024))
+ def test_weight_loading(self):
+ with self.assertRaises(ValueError):
+ keras.applications.MobileNet(weights='unknown',
+ include_top=False)
+ with self.assertRaises(ValueError):
+ keras.applications.MobileNet(weights='imagenet',
+ classes=2000)
+
+ def test_preprocess_input(self):
+ x = np.random.uniform(0, 255, (2, 300, 200, 3))
+ out1 = keras.applications.mobilenet.preprocess_input(x)
+ self.assertAllClose(np.mean(out1), 0., atol=0.1)
+
+ def test_invalid_use_cases(self):
+ keras.backend.set_image_data_format('channels_first')
+ model = keras.applications.MobileNet(weights=None)
+ self.assertEqual(model.output_shape, (None, 1000))
+ keras.backend.set_image_data_format('channels_last')
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/keras/python/keras/applications/resnet50_test.py b/tensorflow/contrib/keras/python/keras/applications/resnet50_test.py
index 0ef701af93..2b00170652 100644
--- a/tensorflow/contrib/keras/python/keras/applications/resnet50_test.py
+++ b/tensorflow/contrib/keras/python/keras/applications/resnet50_test.py
@@ -38,5 +38,14 @@ class ResNet50Test(test.TestCase):
pooling='avg')
self.assertEqual(model.output_shape, (None, 2048))
+ def test_weight_loading(self):
+ with self.assertRaises(ValueError):
+ keras.applications.ResNet50(weights='unknown',
+ include_top=False)
+
+ with self.assertRaises(ValueError):
+ keras.applications.ResNet50(weights='imagenet',
+ classes=2000)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/keras/python/keras/applications/vgg16_test.py b/tensorflow/contrib/keras/python/keras/applications/vgg16_test.py
index d0e707d675..4ba5dabd5a 100644
--- a/tensorflow/contrib/keras/python/keras/applications/vgg16_test.py
+++ b/tensorflow/contrib/keras/python/keras/applications/vgg16_test.py
@@ -38,5 +38,13 @@ class VGG16Test(test.TestCase):
pooling='avg')
self.assertEqual(model.output_shape, (None, 512))
+ def test_weight_loading(self):
+ with self.assertRaises(ValueError):
+ keras.applications.VGG16(weights='unknown',
+ include_top=False)
+ with self.assertRaises(ValueError):
+ keras.applications.VGG16(weights='imagenet',
+ classes=2000)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/keras/python/keras/applications/vgg19_test.py b/tensorflow/contrib/keras/python/keras/applications/vgg19_test.py
index f2db0da4f4..604d4bb2d8 100644
--- a/tensorflow/contrib/keras/python/keras/applications/vgg19_test.py
+++ b/tensorflow/contrib/keras/python/keras/applications/vgg19_test.py
@@ -38,5 +38,13 @@ class VGG19Test(test.TestCase):
pooling='avg')
self.assertEqual(model.output_shape, (None, 512))
+ def test_weight_loading(self):
+ with self.assertRaises(ValueError):
+ keras.applications.VGG19(weights='unknown',
+ include_top=False)
+ with self.assertRaises(ValueError):
+ keras.applications.VGG19(weights='imagenet',
+ classes=2000)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/keras/python/keras/applications/xception_test.py b/tensorflow/contrib/keras/python/keras/applications/xception_test.py
index bb3cc1678e..a941514c3e 100644
--- a/tensorflow/contrib/keras/python/keras/applications/xception_test.py
+++ b/tensorflow/contrib/keras/python/keras/applications/xception_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.keras.python import keras
from tensorflow.python.platform import test
@@ -38,5 +40,18 @@ class XceptionTest(test.TestCase):
pooling='avg')
self.assertEqual(model.output_shape, (None, 2048))
+ def test_weight_loading(self):
+ with self.assertRaises(ValueError):
+ keras.applications.Xception(weights='unknown',
+ include_top=False)
+ with self.assertRaises(ValueError):
+ keras.applications.Xception(weights='imagenet',
+ classes=2000)
+
+ def test_preprocess_input(self):
+ x = np.random.uniform(0, 255, (2, 300, 200, 3))
+ out1 = keras.applications.xception.preprocess_input(x)
+ self.assertAllClose(np.mean(out1), 0., atol=0.1)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/keras/python/keras/backend.py b/tensorflow/contrib/keras/python/keras/backend.py
index 6d7429d20d..99570797af 100644
--- a/tensorflow/contrib/keras/python/keras/backend.py
+++ b/tensorflow/contrib/keras/python/keras/backend.py
@@ -653,7 +653,7 @@ def int_shape(x):
"""
shape = x.get_shape()
try:
- return tuple([i.__int__() for i in shape])
+ return tuple(shape.as_list())
except ValueError:
return None
@@ -3118,40 +3118,6 @@ def _preprocess_conv3d_input(x, data_format):
return x
-def _preprocess_conv2d_kernel(kernel, data_format):
- """Transpose and cast the kernel before the conv2d.
-
- Arguments:
- kernel: kernel tensor.
- data_format: string, one of 'channels_last', 'channels_first'.
-
- Returns:
- A tensor.
- """
- if dtype(kernel) == 'float64':
- kernel = math_ops.cast(kernel, 'float32')
- if data_format == 'channels_first':
- kernel = array_ops.transpose(kernel, (2, 3, 1, 0))
- return kernel
-
-
-def _preprocess_conv3d_kernel(kernel, data_format):
- """Transpose and cast the kernel before the conv3d.
-
- Arguments:
- kernel: kernel tensor.
- data_format: string, one of 'channels_last', 'channels_first'.
-
- Returns:
- A tensor.
- """
- if dtype(kernel) == 'float64':
- kernel = math_ops.cast(kernel, 'float32')
- if data_format == 'channels_first':
- kernel = array_ops.transpose(kernel, (2, 3, 4, 1, 0))
- return kernel
-
-
def _preprocess_padding(padding):
"""Convert keras' padding to tensorflow's padding.
diff --git a/tensorflow/contrib/keras/python/keras/backend_test.py b/tensorflow/contrib/keras/python/keras/backend_test.py
index a2bc95e4a1..69dcf3f094 100644
--- a/tensorflow/contrib/keras/python/keras/backend_test.py
+++ b/tensorflow/contrib/keras/python/keras/backend_test.py
@@ -19,8 +19,10 @@ from __future__ import division
from __future__ import print_function
import numpy as np
+import scipy.sparse
from tensorflow.contrib.keras.python import keras
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
@@ -112,6 +114,31 @@ class BackendUtilsTest(test.TestCase):
keras.backend.reset_uids()
self.assertEqual(keras.backend.get_uid('foo'), 1)
+ def test_learning_phase(self):
+ with self.test_session():
+ keras.backend.set_learning_phase(1)
+ self.assertEqual(keras.backend.learning_phase(), 1)
+ with self.assertRaises(ValueError):
+ keras.backend.set_learning_phase(2)
+
+ def test_int_shape(self):
+ x = keras.backend.placeholder(shape=(3, 4))
+ self.assertEqual(keras.backend.int_shape(x), (3, 4))
+
+ x = keras.backend.placeholder(shape=(None, 4))
+ self.assertEqual(keras.backend.int_shape(x), (None, 4))
+
+ def test_in_train_phase(self):
+ with self.test_session():
+ y1 = keras.backend.variable(1)
+ y2 = keras.backend.variable(2)
+ y = keras.backend.in_train_phase(y1, y2)
+ f = keras.backend.function([keras.backend.learning_phase()], [y])
+ y_val = f([0])[0]
+ self.assertAllClose(y_val, 2)
+ y_val = f([1])[0]
+ self.assertAllClose(y_val, 1)
+
class BackendVariableTest(test.TestCase):
@@ -169,6 +196,28 @@ class BackendVariableTest(test.TestCase):
val = keras.backend.count_params(x)
self.assertAllClose(val, 20)
+ def test_constant(self):
+ with self.test_session():
+ ref_val = np.random.random((3, 4)).astype('float32')
+ x = keras.backend.constant(ref_val)
+ val = keras.backend.eval(x)
+ self.assertAllClose(val, ref_val)
+
+ def test_sparse_variable(self):
+ with self.test_session():
+ val = scipy.sparse.eye(10)
+ x = keras.backend.variable(val)
+ self.assertTrue(isinstance(x, sparse_tensor.SparseTensor))
+
+ y = keras.backend.to_dense(x)
+ self.assertFalse(keras.backend.is_sparse(y))
+
+ def test_placeholder(self):
+ x = keras.backend.placeholder(shape=(3, 4))
+ self.assertEqual(x.get_shape().as_list(), [3, 4])
+ x = keras.backend.placeholder(shape=(3, 4), sparse=True)
+ self.assertEqual(x.get_shape().as_list(), [3, 4])
+
class BackendLinearAlgebraTest(test.TestCase):
@@ -189,6 +238,8 @@ class BackendLinearAlgebraTest(test.TestCase):
xy = keras.backend.batch_dot(x, y, axes=[1, 2])
self.assertEqual(xy.get_shape().as_list(), [32, 1, 30])
+ # TODO(fchollet): insufficiently tested.
+
def test_reduction_ops(self):
ops_to_test = [
(keras.backend.max, np.max),
@@ -315,6 +366,13 @@ class BackendShapeOpsTest(test.TestCase):
data_format)
self.assertEqual(y.get_shape().as_list(), [1, 3, 4, 4])
+ # Invalid use:
+ with self.assertRaises(ValueError):
+ keras.backend.resize_images(x,
+ height_factor,
+ width_factor,
+ data_format='unknown')
+
def test_resize_volumes(self):
height_factor = 2
width_factor = 2
@@ -337,11 +395,24 @@ class BackendShapeOpsTest(test.TestCase):
data_format)
self.assertEqual(y.get_shape().as_list(), [1, 3, 4, 4, 4])
+ # Invalid use:
+ with self.assertRaises(ValueError):
+ keras.backend.resize_volumes(x,
+ depth_factor,
+ height_factor,
+ width_factor,
+ data_format='unknown')
+
def test_repeat_elements(self):
x = keras.backend.variable(np.ones((1, 3, 2)))
y = keras.backend.repeat_elements(x, 3, axis=1)
self.assertEqual(y.get_shape().as_list(), [1, 9, 2])
+ # Invalid use:
+ with self.assertRaises(ValueError):
+ x = keras.backend.placeholder(shape=(2, None, 2))
+ keras.backend.repeat_elements(x, 3, axis=1)
+
def test_repeat(self):
x = keras.backend.variable(np.ones((1, 3)))
y = keras.backend.repeat(x, 2)
@@ -455,5 +526,489 @@ class BackendShapeOpsTest(test.TestCase):
np_kwargs={'data_format': 'channels_first'})
+class BackendNNOpsTest(test.TestCase):
+
+ def test_bias_add(self):
+ with self.test_session():
+ keras_op = keras.backend.bias_add
+ np_op = np.add
+ compare_two_inputs_op_to_numpy(keras_op, np_op,
+ input_shape_a=(4, 7),
+ input_shape_b=(7,))
+ compare_two_inputs_op_to_numpy(keras_op, np_op,
+ input_shape_a=(4, 3, 7),
+ input_shape_b=(7,))
+ compare_two_inputs_op_to_numpy(keras_op, np_op,
+ input_shape_a=(4, 3, 5, 7),
+ input_shape_b=(7,))
+ compare_two_inputs_op_to_numpy(keras_op, np_op,
+ input_shape_a=(4, 3, 5, 2, 7),
+ input_shape_b=(7,))
+
+ with self.assertRaises(ValueError):
+ x = keras.backend.variable((3, 4))
+ b = keras.backend.variable((3, 4))
+ keras.backend.bias_add(x, b)
+ with self.assertRaises(ValueError):
+ x = keras.backend.variable((3, 4))
+ b = keras.backend.variable((4,))
+ keras.backend.bias_add(x, b, data_format='unknown')
+
+ def test_bias_add_channels_first(self):
+ with self.test_session():
+ def keras_op(x, b):
+ return keras.backend.bias_add(x, b, data_format='channels_first')
+
+ def np_op(x, b):
+ if x.ndim == 3:
+ b = b.reshape((1, b.shape[0], 1))
+ if x.ndim == 4:
+ b = b.reshape((1, b.shape[0], 1, 1))
+ return x + b
+
+ compare_two_inputs_op_to_numpy(keras_op, np_op,
+ input_shape_a=(4, 3, 7),
+ input_shape_b=(3,))
+ compare_two_inputs_op_to_numpy(keras_op, np_op,
+ input_shape_a=(4, 3, 5, 7),
+ input_shape_b=(3,))
+
+ def test_pool2d(self):
+ val = np.random.random((10, 3, 10, 10))
+ x = keras.backend.variable(val)
+ y = keras.backend.pool2d(x, (2, 2), strides=(1, 1),
+ padding='valid', data_format='channels_first',
+ pool_mode='max')
+ self.assertEqual(y.get_shape().as_list(), [10, 3, 9, 9])
+
+ y = keras.backend.pool2d(x, (2, 2), strides=(1, 1),
+ padding='valid', data_format='channels_first',
+ pool_mode='avg')
+ self.assertEqual(y.get_shape().as_list(), [10, 3, 9, 9])
+
+ val = np.random.random((10, 10, 10, 3))
+ x = keras.backend.variable(val)
+ y = keras.backend.pool2d(x, (2, 2), strides=(1, 1),
+ padding='valid', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 9, 9, 3])
+
+ val = np.random.random((10, 10, 10, 3))
+ x = keras.backend.variable(val)
+ y = keras.backend.pool2d(x, (2, 2), strides=(1, 1),
+ padding='same', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 10, 10, 3])
+
+ val = np.random.random((10, 10, 10, 3))
+ x = keras.backend.variable(val)
+ y = keras.backend.pool2d(x, (2, 2), strides=(2, 2),
+ padding='same', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 5, 5, 3])
+
+ with self.assertRaises(ValueError):
+ y = keras.backend.pool2d(x, (2, 2), strides=(2, 2),
+ padding='other', data_format='channels_last')
+ with self.assertRaises(ValueError):
+ y = keras.backend.pool2d(x, (2, 2), strides=(2, 2),
+ data_format='other')
+ with self.assertRaises(ValueError):
+ y = keras.backend.pool2d(x, (2, 2, 2), strides=(2, 2))
+ with self.assertRaises(ValueError):
+ y = keras.backend.pool2d(x, (2, 2), strides=(2, 2, 2))
+ with self.assertRaises(ValueError):
+ y = keras.backend.pool2d(x, (2, 2), strides=(2, 2), pool_mode='other')
+
+ def test_pool3d(self):
+ val = np.random.random((10, 3, 10, 10, 10))
+ x = keras.backend.variable(val)
+ y = keras.backend.pool3d(x, (2, 2, 2), strides=(1, 1, 1),
+ padding='valid', data_format='channels_first',
+ pool_mode='max')
+ self.assertEqual(y.get_shape().as_list(), [10, 3, 9, 9, 9])
+
+ y = keras.backend.pool3d(x, (2, 2, 2), strides=(1, 1, 1),
+ padding='valid', data_format='channels_first',
+ pool_mode='avg')
+ self.assertEqual(y.get_shape().as_list(), [10, 3, 9, 9, 9])
+
+ val = np.random.random((10, 10, 10, 10, 3))
+ x = keras.backend.variable(val)
+ y = keras.backend.pool3d(x, (2, 2, 2), strides=(1, 1, 1),
+ padding='valid', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 9, 9, 9, 3])
+
+ val = np.random.random((10, 10, 10, 10, 3))
+ x = keras.backend.variable(val)
+ y = keras.backend.pool3d(x, (2, 2, 2), strides=(1, 1, 1),
+ padding='same', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 10, 10, 10, 3])
+
+ val = np.random.random((10, 10, 10, 10, 3))
+ x = keras.backend.variable(val)
+ y = keras.backend.pool3d(x, (2, 2, 2), strides=(2, 2, 2),
+ padding='same', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 5, 5, 5, 3])
+
+ def test_conv1d(self):
+ val = np.random.random((10, 4, 10))
+ x = keras.backend.variable(val)
+ kernel_val = np.random.random((3, 4, 5))
+ k = keras.backend.variable(kernel_val)
+ y = keras.backend.conv1d(x, k, strides=(1,),
+ padding='valid', data_format='channels_first')
+ self.assertEqual(y.get_shape().as_list(), [10, 5, 8])
+
+ val = np.random.random((10, 10, 4))
+ x = keras.backend.variable(val)
+ y = keras.backend.conv1d(x, k, strides=(1,),
+ padding='valid', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 8, 5])
+
+ val = np.random.random((10, 10, 4))
+ x = keras.backend.variable(val)
+ y = keras.backend.conv1d(x, k, strides=(1,),
+ padding='same', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 10, 5])
+
+ val = np.random.random((10, 10, 4))
+ x = keras.backend.variable(val)
+ y = keras.backend.conv1d(x, k, strides=(2,),
+ padding='same', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 5, 5])
+
+ def test_conv2d(self):
+ val = np.random.random((10, 4, 10, 10))
+ x = keras.backend.variable(val)
+ kernel_val = np.random.random((3, 3, 4, 5))
+ k = keras.backend.variable(kernel_val)
+ y = keras.backend.conv2d(x, k,
+ padding='valid', data_format='channels_first')
+ self.assertEqual(y.get_shape().as_list(), [10, 5, 8, 8])
+
+ val = np.random.random((10, 10, 10, 4))
+ x = keras.backend.variable(val)
+ y = keras.backend.conv2d(x, k, strides=(1, 1),
+ padding='valid', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 8, 8, 5])
+
+ val = np.random.random((10, 10, 10, 4))
+ x = keras.backend.variable(val)
+ y = keras.backend.conv2d(x, k, strides=(1, 1),
+ padding='same', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 10, 10, 5])
+
+ val = np.random.random((10, 10, 10, 4))
+ x = keras.backend.variable(val)
+ y = keras.backend.conv2d(x, k, strides=(2, 2),
+ padding='same', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 5, 5, 5])
+ with self.assertRaises(ValueError):
+ y = keras.backend.conv2d(x, k, (2, 2),
+ padding='other', data_format='channels_last')
+ with self.assertRaises(ValueError):
+ y = keras.backend.conv2d(x, k, (2, 2),
+ data_format='other')
+ with self.assertRaises(ValueError):
+ y = keras.backend.conv2d(x, k, (2, 2, 2))
+
+ def test_separable_conv2d(self):
+ val = np.random.random((10, 4, 10, 10))
+ x = keras.backend.variable(val)
+ depthwise_kernel_val = np.random.random((3, 3, 4, 1))
+ pointwise_kernel_val = np.random.random((1, 1, 4, 5))
+ dk = keras.backend.variable(depthwise_kernel_val)
+ pk = keras.backend.variable(pointwise_kernel_val)
+ y = keras.backend.separable_conv2d(
+ x, dk, pk, padding='valid', data_format='channels_first')
+ self.assertEqual(y.get_shape().as_list(), [10, 5, 8, 8])
+
+ val = np.random.random((10, 10, 10, 4))
+ x = keras.backend.variable(val)
+ y = keras.backend.separable_conv2d(
+ x, dk, pk, strides=(1, 1), padding='valid', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 8, 8, 5])
+
+ val = np.random.random((10, 10, 10, 4))
+ x = keras.backend.variable(val)
+ y = keras.backend.separable_conv2d(
+ x, dk, pk, strides=(1, 1), padding='same', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 10, 10, 5])
+
+ val = np.random.random((10, 10, 10, 4))
+ x = keras.backend.variable(val)
+ y = keras.backend.separable_conv2d(
+ x, dk, pk, strides=(2, 2), padding='same', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 5, 5, 5])
+ with self.assertRaises(ValueError):
+ y = keras.backend.separable_conv2d(
+ x, dk, pk, (2, 2), padding='other', data_format='channels_last')
+ with self.assertRaises(ValueError):
+ y = keras.backend.separable_conv2d(
+ x, dk, pk, (2, 2), data_format='other')
+ with self.assertRaises(ValueError):
+ y = keras.backend.separable_conv2d(x, dk, pk, (2, 2, 2))
+
+ def test_conv3d(self):
+ val = np.random.random((10, 4, 10, 10, 10))
+ x = keras.backend.variable(val)
+ kernel_val = np.random.random((3, 3, 3, 4, 5))
+ k = keras.backend.variable(kernel_val)
+ y = keras.backend.conv3d(x, k,
+ padding='valid', data_format='channels_first')
+ self.assertEqual(y.get_shape().as_list(), [10, 5, 8, 8, 8])
+
+ val = np.random.random((10, 10, 10, 10, 4))
+ x = keras.backend.variable(val)
+ y = keras.backend.conv3d(x, k, strides=(1, 1, 1),
+ padding='valid', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 8, 8, 8, 5])
+
+ val = np.random.random((10, 10, 10, 10, 4))
+ x = keras.backend.variable(val)
+ y = keras.backend.conv3d(x, k, strides=(1, 1, 1),
+ padding='same', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 10, 10, 10, 5])
+
+ val = np.random.random((10, 10, 10, 10, 4))
+ x = keras.backend.variable(val)
+ y = keras.backend.conv3d(x, k, strides=(2, 2, 2),
+ padding='same', data_format='channels_last')
+ self.assertEqual(y.get_shape().as_list(), [10, 5, 5, 5, 5])
+ with self.assertRaises(ValueError):
+ y = keras.backend.conv3d(x, k, (2, 2, 2),
+ padding='other', data_format='channels_last')
+ with self.assertRaises(ValueError):
+ y = keras.backend.conv3d(x, k, (2, 2, 2),
+ data_format='other')
+ with self.assertRaises(ValueError):
+ y = keras.backend.conv3d(x, k, (2, 2))
+
+ def test_rnn(self):
+ # implement a simple RNN
+ num_samples = 4
+ input_dim = 5
+ output_dim = 3
+ timesteps = 6
+
+ input_val = np.random.random(
+ (num_samples, timesteps, input_dim)).astype(np.float32)
+ init_state_val = np.random.random(
+ (num_samples, output_dim)).astype(np.float32)
+ w_i_val = np.random.random((input_dim, output_dim)).astype(np.float32)
+ w_o_val = np.random.random((output_dim, output_dim)).astype(np.float32)
+ np_mask = np.random.randint(2, size=(num_samples, timesteps))
+
+ def rnn_step_fn():
+ w_i = keras.backend.variable(w_i_val)
+ w_o = keras.backend.variable(w_o_val)
+
+ def step_function(x, states):
+ assert len(states) == 1
+ prev_output = states[0]
+ output = keras.backend.dot(x, w_i) + keras.backend.dot(prev_output, w_o)
+ return output, [output]
+
+ return step_function
+
+ # test default setup
+ last_output_list = [[], [], [], [], [], []]
+ outputs_list = [[], [], [], [], [], []]
+ state_list = [[], [], [], [], [], []]
+
+ rnn_fn = rnn_step_fn()
+ inputs = keras.backend.variable(input_val)
+ initial_states = [keras.backend.variable(init_state_val)]
+ mask = keras.backend.variable(np_mask)
+
+ kwargs_list = [
+ {'go_backwards': False, 'mask': None},
+ {'go_backwards': False, 'mask': None, 'unroll': True},
+ {'go_backwards': True, 'mask': None},
+ {'go_backwards': True, 'mask': None, 'unroll': True},
+ {'go_backwards': False, 'mask': mask},
+ {'go_backwards': False, 'mask': mask, 'unroll': True},
+ ]
+
+ for (i, kwargs) in enumerate(kwargs_list):
+ last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs,
+ initial_states,
+ **kwargs)
+ last_output_list[i].append(keras.backend.eval(last_output))
+ outputs_list[i].append(keras.backend.eval(outputs))
+ self.assertEqual(len(new_states), 1)
+ state_list[i].append(keras.backend.eval(new_states[0]))
+
+ def assert_list_pairwise(z_list, atol=1e-05):
+ for (z1, z2) in zip(z_list[1:], z_list[:-1]):
+ self.assertAllClose(z1, z2, atol=atol)
+
+ assert_list_pairwise(last_output_list[0], atol=1e-04)
+ assert_list_pairwise(outputs_list[0], atol=1e-04)
+ assert_list_pairwise(state_list[0], atol=1e-04)
+ assert_list_pairwise(last_output_list[2], atol=1e-04)
+ assert_list_pairwise(outputs_list[2], atol=1e-04)
+ assert_list_pairwise(state_list[2], atol=1e-04)
+
+ for l, u_l in zip(last_output_list[0], last_output_list[1]):
+ self.assertAllClose(l, u_l, atol=1e-04)
+
+ for o, u_o in zip(outputs_list[0], outputs_list[1]):
+ self.assertAllClose(o, u_o, atol=1e-04)
+
+ for s, u_s in zip(state_list[0], state_list[1]):
+ self.assertAllClose(s, u_s, atol=1e-04)
+
+ for b_l, b_u_l in zip(last_output_list[2], last_output_list[3]):
+ self.assertAllClose(b_l, b_u_l, atol=1e-04)
+
+ for b_o, b_u_o in zip(outputs_list[2], outputs_list[3]):
+ self.assertAllClose(b_o, b_u_o, atol=1e-04)
+
+ for b_s, b_u_s in zip(state_list[2], state_list[3]):
+ self.assertAllClose(b_s, b_u_s, atol=1e-04)
+
+ def test_normalize_batch_in_training(self):
+ val = np.random.random((10, 3, 10, 10))
+ x = keras.backend.variable(val)
+ reduction_axes = (0, 2, 3)
+
+ # case: need broadcasting
+ g_val = np.random.random((3,))
+ b_val = np.random.random((3,))
+ gamma = keras.backend.variable(g_val)
+ beta = keras.backend.variable(b_val)
+ normed, mean, var = keras.backend.normalize_batch_in_training(
+ x, gamma, beta, reduction_axes, epsilon=1e-3)
+ self.assertEqual(normed.get_shape().as_list(), [10, 3, 10, 10])
+ self.assertEqual(mean.get_shape().as_list(), [3,])
+ self.assertEqual(var.get_shape().as_list(), [3,])
+
+ # case: doesn't need broadcasting
+ g_val = np.random.random((1, 3, 1, 1))
+ b_val = np.random.random((1, 3, 1, 1))
+ gamma = keras.backend.variable(g_val)
+ beta = keras.backend.variable(b_val)
+ normed, mean, var = keras.backend.normalize_batch_in_training(
+ x, gamma, beta, reduction_axes, epsilon=1e-3)
+ self.assertEqual(normed.get_shape().as_list(), [10, 3, 10, 10])
+ self.assertEqual(mean.get_shape().as_list(), [3,])
+ self.assertEqual(var.get_shape().as_list(), [3,])
+
+ # case: gamma=None
+ gamma = None
+ normed, mean, var = keras.backend.normalize_batch_in_training(
+ x, gamma, beta, reduction_axes, epsilon=1e-3)
+ self.assertEqual(normed.get_shape().as_list(), [10, 3, 10, 10])
+ self.assertEqual(mean.get_shape().as_list(), [3,])
+ self.assertEqual(var.get_shape().as_list(), [3,])
+
+ # case: beta=None
+ beta = None
+ normed, mean, var = keras.backend.normalize_batch_in_training(
+ x, gamma, beta, reduction_axes, epsilon=1e-3)
+ self.assertEqual(normed.get_shape().as_list(), [10, 3, 10, 10])
+ self.assertEqual(mean.get_shape().as_list(), [3,])
+ self.assertEqual(var.get_shape().as_list(), [3,])
+
+
+class TestCTC(test.TestCase):
+
+ def test_ctc_decode(self):
+ with self.test_session():
+ depth = 6
+ seq_len_0 = 5
+ input_prob_matrix_0 = np.asarray(
+ [[0.30999, 0.309938, 0.0679938, 0.0673362, 0.0708352, 0.173908],
+ [0.215136, 0.439699, 0.0370931, 0.0393967, 0.0381581, 0.230517],
+ [0.199959, 0.489485, 0.0233221, 0.0251417, 0.0233289, 0.238763],
+ [0.279611, 0.452966, 0.0204795, 0.0209126, 0.0194803, 0.20655],
+ [0.51286, 0.288951, 0.0243026, 0.0220788, 0.0219297, 0.129878],
+ # Random entry added in at time=5
+ [0.155251, 0.164444, 0.173517, 0.176138, 0.169979, 0.160671]],
+ dtype=np.float32)
+
+ # len max_time_steps array of batch_size x depth matrices
+ inputs = ([input_prob_matrix_0[t, :][np.newaxis, :]
+ for t in range(seq_len_0)] + # Pad to max_time_steps = 8
+ 2 * [np.zeros((1, depth), dtype=np.float32)])
+
+ inputs = keras.backend.variable(np.asarray(inputs).transpose((1, 0, 2)))
+
+ # batch_size length vector of sequence_lengths
+ input_length = keras.backend.variable(
+ np.array([seq_len_0], dtype=np.int32))
+ # batch_size length vector of negative log probabilities
+ log_prob_truth = np.array([
+ 0.584855, # output beam 0
+ 0.389139 # output beam 1
+ ], np.float32)[np.newaxis, :]
+
+ decode_truth = [np.array([1, 0]), np.array([0, 1, 0])]
+ beam_width = 2
+ top_paths = 2
+
+ decode_pred_tf, log_prob_pred_tf = keras.backend.ctc_decode(
+ inputs,
+ input_length,
+ greedy=False,
+ beam_width=beam_width,
+ top_paths=top_paths)
+
+ self.assertEqual(len(decode_pred_tf), top_paths)
+ log_prob_pred = keras.backend.eval(log_prob_pred_tf)
+ for i in range(top_paths):
+ self.assertTrue(
+ np.alltrue(
+ decode_truth[i] == keras.backend.eval(decode_pred_tf[i])))
+ self.assertAllClose(log_prob_truth, log_prob_pred)
+
+ def test_ctc_batch_cost(self):
+ with self.test_session():
+ label_lens = np.expand_dims(np.asarray([5, 4]), 1)
+ input_lens = np.expand_dims(np.asarray([5, 5]), 1) # number of timesteps
+ loss_log_probs = [3.34211, 5.42262]
+
+ # dimensions are batch x time x categories
+ labels = np.asarray([[0, 1, 2, 1, 0], [0, 1, 1, 0, -1]])
+ inputs = np.asarray(
+ [[[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
+ [0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
+ [0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
+ [0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
+ [0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]],
+ [[0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508],
+ [0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549],
+ [0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456],
+ [0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345],
+ [0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]]],
+ dtype=np.float32)
+
+ labels = keras.backend.variable(labels, dtype='int32')
+ inputs = keras.backend.variable(inputs, dtype='float32')
+ input_lens = keras.backend.variable(input_lens, dtype='int32')
+ label_lens = keras.backend.variable(label_lens, dtype='int32')
+ res = keras.backend.eval(
+ keras.backend.ctc_batch_cost(labels, inputs, input_lens, label_lens))
+ self.assertAllClose(res[:, 0], loss_log_probs, atol=1e-05)
+
+
+class TestRandomOps(test.TestCase):
+
+ def test_random_binomial(self):
+ with self.test_session():
+ np.random.seed(123)
+ x = keras.backend.random_binomial((1000, 1000), p=0.5)
+ self.assertAllClose(np.mean(keras.backend.eval(x)), 0.5, atol=0.1)
+
+ def test_truncated_normal(self):
+ with self.test_session():
+ np.random.seed(123)
+ x = keras.backend.truncated_normal((1000, 1000), mean=0.0, stddev=1.0)
+ y = keras.backend.eval(x)
+ self.assertAllClose(np.mean(y), 0., atol=0.1)
+ self.assertAllClose(np.std(y), 0.88, atol=0.1)
+ self.assertAllClose(np.max(y), 2., atol=0.1)
+ self.assertAllClose(np.min(y), -2., atol=0.1)
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/keras/python/keras/callbacks.py b/tensorflow/contrib/keras/python/keras/callbacks.py
index 6df6662081..06a5f4ad8f 100644
--- a/tensorflow/contrib/keras/python/keras/callbacks.py
+++ b/tensorflow/contrib/keras/python/keras/callbacks.py
@@ -486,7 +486,7 @@ class EarlyStopping(Callback):
if mode not in ['auto', 'min', 'max']:
logging.warning('EarlyStopping mode %s is unknown, '
- 'fallback to auto mode.' % (self.mode))
+ 'fallback to auto mode.' % mode)
mode = 'auto'
if mode == 'min':
@@ -494,7 +494,7 @@ class EarlyStopping(Callback):
elif mode == 'max':
self.monitor_op = np.greater
else:
- if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
+ if 'acc' in self.monitor:
self.monitor_op = np.greater
else:
self.monitor_op = np.less
@@ -516,7 +516,7 @@ class EarlyStopping(Callback):
logging.warning('Early stopping conditioned on metric `%s` '
'which is not available. Available metrics are: %s' %
(self.monitor, ','.join(list(logs.keys()))))
-
+ return
if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
diff --git a/tensorflow/contrib/keras/python/keras/callbacks_test.py b/tensorflow/contrib/keras/python/keras/callbacks_test.py
index 15a7304b60..d8c5c0337f 100644
--- a/tensorflow/contrib/keras/python/keras/callbacks_test.py
+++ b/tensorflow/contrib/keras/python/keras/callbacks_test.py
@@ -35,6 +35,11 @@ try:
except ImportError:
h5py = None
+try:
+ import requests # pylint:disable=g-import-not-at-top
+except ImportError:
+ requests = None
+
TRAIN_SAMPLES = 10
TEST_SAMPLES = 10
@@ -158,6 +163,24 @@ class KerasCallbacksTest(test.TestCase):
assert os.path.exists(filepath)
os.remove(filepath)
+ # Case: metric not available.
+ cbks = [
+ keras.callbacks.ModelCheckpoint(
+ filepath,
+ monitor='unknown',
+ save_best_only=True)
+ ]
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=BATCH_SIZE,
+ validation_data=(x_test, y_test),
+ callbacks=cbks,
+ epochs=1,
+ verbose=0)
+ # File won't be written.
+ assert not os.path.exists(filepath)
+
# case 5
save_best_only = False
period = 2
@@ -179,7 +202,7 @@ class KerasCallbacksTest(test.TestCase):
validation_data=(x_test, y_test),
callbacks=cbks,
epochs=4,
- verbose=0)
+ verbose=1)
assert os.path.exists(filepath.format(epoch=1))
assert os.path.exists(filepath.format(epoch=3))
os.remove(filepath.format(epoch=1))
@@ -187,9 +210,16 @@ class KerasCallbacksTest(test.TestCase):
assert not os.path.exists(filepath.format(epoch=0))
assert not os.path.exists(filepath.format(epoch=2))
+ # Invalid use: this will raise a warning but not an Exception.
+ keras.callbacks.ModelCheckpoint(
+ filepath,
+ monitor=monitor,
+ save_best_only=save_best_only,
+ mode='unknown')
+
def test_EarlyStopping(self):
with self.test_session():
- np.random.seed(1337)
+ np.random.seed(123)
(x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
train_samples=TRAIN_SAMPLES,
test_samples=TEST_SAMPLES,
@@ -206,37 +236,28 @@ class KerasCallbacksTest(test.TestCase):
loss='categorical_crossentropy',
optimizer='rmsprop',
metrics=['accuracy'])
- mode = 'max'
- monitor = 'val_acc'
- patience = 0
- cbks = [
- keras.callbacks.EarlyStopping(
- patience=patience, monitor=monitor, mode=mode)
- ]
- model.fit(
- x_train,
- y_train,
- batch_size=BATCH_SIZE,
- validation_data=(x_test, y_test),
- callbacks=cbks,
- epochs=20,
- verbose=0)
- mode = 'auto'
- monitor = 'val_acc'
- patience = 2
- cbks = [
- keras.callbacks.EarlyStopping(
- patience=patience, monitor=monitor, mode=mode)
+ cases = [
+ ('max', 'val_acc'),
+ ('min', 'val_loss'),
+ ('auto', 'val_acc'),
+ ('auto', 'loss'),
+ ('unknown', 'unknown')
]
- model.fit(
- x_train,
- y_train,
- batch_size=BATCH_SIZE,
- validation_data=(x_test, y_test),
- callbacks=cbks,
- epochs=20,
- verbose=0)
+ for mode, monitor in cases:
+ patience = 0
+ cbks = [
+ keras.callbacks.EarlyStopping(
+ patience=patience, monitor=monitor, mode=mode)
+ ]
+ model.fit(
+ x_train,
+ y_train,
+ batch_size=BATCH_SIZE,
+ validation_data=(x_test, y_test),
+ callbacks=cbks,
+ epochs=5,
+ verbose=0)
def test_EarlyStopping_reuse(self):
with self.test_session():
@@ -260,6 +281,14 @@ class KerasCallbacksTest(test.TestCase):
hist = model.fit(data, labels, callbacks=[stopper], verbose=0)
assert len(hist.epoch) >= patience
+ def test_RemoteMonitor(self):
+ if requests is None:
+ return
+
+ monitor = keras.callbacks.RemoteMonitor()
+ # This will raise a warning since the default address in unreachable:
+ monitor.on_epoch_end(0, logs={'loss': 0.})
+
def test_LearningRateScheduler(self):
with self.test_session():
np.random.seed(1337)
diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/contrib/keras/python/keras/engine/topology.py
index 67883bfb24..8f69dbf49c 100644
--- a/tensorflow/contrib/keras/python/keras/engine/topology.py
+++ b/tensorflow/contrib/keras/python/keras/engine/topology.py
@@ -22,7 +22,6 @@ from __future__ import print_function
import copy
import json
import os
-import re
import numpy as np
from six.moves import zip # pylint: disable=redefined-builtin
@@ -734,17 +733,6 @@ class Network(tf_base_layers.Network, Layer):
state_updates += layer.updates
return state_updates
- @property
- def constraints(self):
- cons = {}
- for layer in self.layers:
- for key, value in layer.constraints.items():
- if key in cons and cons[key] != value:
- raise ValueError('Received multiple constraints '
- 'for one weight tensor: ' + str(key))
- cons[key] = value
- return cons
-
def get_weights(self):
"""Retrieves the weights of the model.
@@ -1212,39 +1200,6 @@ def _to_list(x):
return [x]
-def _object_list_uid(object_list):
- object_list = _to_list(object_list)
- return ', '.join([str(abs(id(x))) for x in object_list])
-
-
-def _to_snake_case(name):
- intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
- insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
- # If the class is private the name starts with "_" which is not secure
- # for creating scopes. We prefix the name with "private" in this case.
- if insecure[0] != '_':
- return insecure
- return 'private' + insecure
-
-
-def _collect_input_shape(input_tensors):
- """Collects the output shape(s) of a list of Keras tensors.
-
- Arguments:
- input_tensors: list of input tensors (or single input tensor).
-
- Returns:
- List of shape tuples (or single tuple), one tuple per input.
- """
- input_tensors = _to_list(input_tensors)
- shapes = []
- for x in input_tensors:
- shapes.append(K.int_shape(x))
- if len(shapes) == 1:
- return shapes[0]
- return shapes
-
-
def save_weights_to_hdf5_group(f, layers):
from tensorflow.contrib.keras.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
diff --git a/tensorflow/contrib/keras/python/keras/engine/topology_test.py b/tensorflow/contrib/keras/python/keras/engine/topology_test.py
index f6c0b8a607..fa099515ab 100644
--- a/tensorflow/contrib/keras/python/keras/engine/topology_test.py
+++ b/tensorflow/contrib/keras/python/keras/engine/topology_test.py
@@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+import shutil
+
import numpy as np
from tensorflow.contrib.keras.python import keras
@@ -30,6 +33,11 @@ try:
except ImportError:
yaml = None
+try:
+ import h5py # pylint:disable=g-import-not-at-top
+except ImportError:
+ h5py = None
+
class TopologyConstructionTest(test.TestCase):
@@ -92,6 +100,41 @@ class TopologyConstructionTest(test.TestCase):
self.assertListEqual(model.trainable_weights, [])
self.assertListEqual(model.non_trainable_weights, weights)
+ def test_weight_loading(self):
+ with self.test_session():
+ a = keras.layers.Input(shape=(2,))
+ x = keras.layers.Dense(3)(a)
+ b = keras.layers.Dense(1)(x)
+ model = keras.models.Model(a, b)
+
+ x = np.random.random((3, 2))
+ ref_y = model.predict(x)
+ weights = model.get_weights()
+ model.set_weights(weights)
+ y = model.predict(x)
+ self.assertAllClose(ref_y, y)
+
+ with self.assertRaises(ValueError):
+ model.set_weights(weights[1:])
+ with self.assertRaises(ValueError):
+ model.set_weights(weights[::-1])
+
+ if h5py is None:
+ return # Skip rest of test if H5py isn't available.
+
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ h5_path = os.path.join(temp_dir, 'test.h5')
+ model.save_weights(h5_path)
+ model.load_weights(h5_path)
+ y = model.predict(x)
+ self.assertAllClose(ref_y, y)
+
+ model.load_weights(h5_path, by_name=True)
+ y = model.predict(x)
+ self.assertAllClose(ref_y, y)
+
def test_learning_phase(self):
with self.test_session():
a = keras.layers.Input(shape=(32,), name='input_a')
@@ -154,6 +197,11 @@ class TopologyConstructionTest(test.TestCase):
a = keras.layers.Input(shape=(32,), name='input_a')
b = keras.layers.Input(shape=(32,), name='input_b')
+ with self.assertRaises(ValueError):
+ _ = keras.layers.Input(shape=(32,), batch_shape=(10, 32))
+ with self.assertRaises(ValueError):
+ _ = keras.layers.Input(shape=(32,), unknwon_kwarg=None)
+
self.assertListEqual(a.get_shape().as_list(), [None, 32])
a_layer, a_node_index, a_tensor_index = a._keras_history
b_layer, _, _ = b._keras_history
@@ -498,6 +546,96 @@ class TopologyConstructionTest(test.TestCase):
model = keras.models.Model(a, b)
self.assertEqual(model.output_mask.get_shape().as_list(), [None, 10])
+ def test_weight_preprocessing(self):
+ input_dim = 3
+ output_dim = 3
+ size = 2
+ cases = [
+ [
+ (keras.layers.Bidirectional(keras.layers.SimpleRNN(2))),
+ [np.random.random((2, 1)), np.random.random((2, 1))],
+ (None, 3, 2),
+ ],
+ [
+ (keras.layers.TimeDistributed(keras.layers.Dense(1))),
+ [np.random.random((2, 1)), np.random.random((1,))],
+ (None, 3, 2),
+ ],
+ [
+ (keras.layers.Conv1D(output_dim, size, use_bias=False)),
+ [np.random.random((output_dim, input_dim, size, 1))],
+ (None, 4, input_dim),
+ ],
+ [
+ (keras.layers.Conv2D(output_dim, size,
+ use_bias=False, data_format='channels_first')),
+ [np.random.random((output_dim, input_dim, size, size))],
+ (None, input_dim, 4, 4),
+ ],
+ [
+ (keras.layers.Conv2DTranspose(output_dim, size,
+ use_bias=False,
+ data_format='channels_first')),
+ [np.random.random((output_dim, input_dim, size, size))],
+ (None, input_dim, 4, 4),
+ ],
+ [
+ (keras.layers.Conv2DTranspose(output_dim, size,
+ use_bias=False,
+ data_format='channels_last')),
+ [np.random.random((size, size, input_dim, output_dim))],
+ (None, 4, 4, input_dim),
+ ],
+ [
+ (keras.layers.Conv3D(output_dim, size,
+ use_bias=False, data_format='channels_first')),
+ [np.random.random((output_dim, input_dim, size, size, size))],
+ (None, input_dim, 4, 4, 4),
+ ],
+ [
+ (keras.layers.GRU(output_dim)),
+ [np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,)),
+ np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,)),
+ np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,))],
+ (None, 4, input_dim),
+ ],
+ [
+ (keras.layers.LSTM(output_dim)),
+ [np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,)),
+ np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,)),
+ np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,)),
+ np.random.random((input_dim, output_dim)),
+ np.random.random((output_dim, output_dim)),
+ np.random.random((output_dim,))],
+ (None, 4, input_dim),
+ ],
+ ]
+ for layer, weights, input_shape in cases:
+ layer.build(input_shape)
+ _ = keras.engine.topology.preprocess_weights_for_loading(
+ layer, weights, original_keras_version='1')
+
+ model = keras.models.Sequential([keras.layers.Dense(2, input_dim=2)])
+ _ = keras.engine.topology.preprocess_weights_for_loading(
+ model, model.weights, original_keras_version='1')
+
+ x = keras.Input((2,))
+ y = keras.layers.Dense(2)(x)
+ model = keras.models.Model(x, y)
+ _ = keras.engine.topology.preprocess_weights_for_loading(
+ model, model.weights, original_keras_version='1')
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/keras/python/keras/engine/training.py b/tensorflow/contrib/keras/python/keras/engine/training.py
index 8269913ea1..fabfa537d8 100644
--- a/tensorflow/contrib/keras/python/keras/engine/training.py
+++ b/tensorflow/contrib/keras/python/keras/engine/training.py
@@ -668,8 +668,7 @@ class Model(Container):
'Output "' + name + '" missing from loss dictionary. '
'We assume this was done on purpose, '
'and we will not be expecting '
- 'any data to be passed to "' + name + '" during training.',
- stacklevel=2)
+ 'any data to be passed to "' + name + '" during training.')
loss_functions.append(losses.get(loss.get(name)))
elif isinstance(loss, list):
if len(loss) != len(self.outputs):
diff --git a/tensorflow/contrib/keras/python/keras/engine/training_test.py b/tensorflow/contrib/keras/python/keras/engine/training_test.py
index d2aac54c94..ad6812ddaf 100644
--- a/tensorflow/contrib/keras/python/keras/engine/training_test.py
+++ b/tensorflow/contrib/keras/python/keras/engine/training_test.py
@@ -162,6 +162,41 @@ class TrainingTest(test.TestCase):
batch_size=5,
verbose=0)
+ # Invalid use cases
+ with self.assertRaises(ValueError):
+ model.train_on_batch({'input_a': input_a_np},
+ [output_d_np, output_e_np])
+ with self.assertRaises(TypeError):
+ model.fit(
+ [input_a_np, input_b_np], [output_d_np, output_e_np],
+ epochs=1,
+ validation_data=([input_a_np, input_b_np], 0, 0),
+ verbose=0)
+ with self.assertRaises(ValueError):
+ model.train_on_batch([input_a_np], [output_d_np, output_e_np])
+ with self.assertRaises(TypeError):
+ model.train_on_batch(1, [output_d_np, output_e_np])
+ with self.assertRaises(ValueError):
+ model.train_on_batch(input_a_np, [output_d_np, output_e_np])
+ with self.assertRaises(ValueError):
+ bad_input = np.random.random((11, 3))
+ model.train_on_batch([bad_input, input_b_np],
+ [output_d_np, output_e_np])
+ with self.assertRaises(ValueError):
+ bad_target = np.random.random((11, 4))
+ model.train_on_batch([input_a_np, input_b_np],
+ [bad_target, output_e_np])
+
+ # Build single-input model
+ x = keras.layers.Input(shape=(3,), name='input_a')
+ y = keras.layers.Dense(4)(x)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ # This will work
+ model.fit([input_a_np], output_d_np, epochs=1)
+ with self.assertRaises(ValueError):
+ model.fit([input_a_np, input_a_np], output_d_np, epochs=1)
+
def test_evaluate_predict_on_arrays(self):
with self.test_session():
a = keras.layers.Input(shape=(3,), name='input_a')
@@ -240,6 +275,40 @@ class TrainingTest(test.TestCase):
})
self.assertEqual(len(out), 2)
+ def test_invalid_loss_or_metrics(self):
+ num_classes = 5
+ train_samples = 1000
+ test_samples = 1000
+ input_dim = 5
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(10, input_shape=(input_dim,)))
+ model.add(keras.layers.Activation('relu'))
+ model.add(keras.layers.Dense(num_classes))
+ model.add(keras.layers.Activation('softmax'))
+ model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
+ np.random.seed(1337)
+ (x_train, y_train), (_, _) = testing_utils.get_test_data(
+ train_samples=train_samples,
+ test_samples=test_samples,
+ input_shape=(input_dim,),
+ num_classes=num_classes)
+ with self.assertRaises(ValueError):
+ model.fit(x_train, y_train)
+
+ with self.assertRaises(ValueError):
+ model.fit(x_train, np.concatenate([y_train, y_train], axis=-1))
+
+ with self.assertRaises(TypeError):
+ model.compile(loss='categorical_crossentropy',
+ optimizer='rmsprop',
+ metrics=set(0))
+
+ with self.assertRaises(RuntimeError):
+ model.compile(loss=None,
+ optimizer='rmsprop')
+
class LossWeightingTest(test.TestCase):
@@ -463,7 +532,7 @@ class LossWeightingTest(test.TestCase):
temporal_x_test[test_ids], temporal_y_test[test_ids], verbose=0)
self.assertLess(score, ref_score)
- def test_class_weight_wrong_classes(self):
+ def test_class_weight_invalid_use_case(self):
num_classes = 5
train_samples = 1000
test_samples = 1000
@@ -495,6 +564,44 @@ class LossWeightingTest(test.TestCase):
model.fit(x_train, y_train,
epochs=0, verbose=0, class_weight=class_weight)
+ with self.assertRaises(ValueError):
+ model.compile(
+ loss='binary_crossentropy',
+ optimizer='rmsprop',
+ sample_weight_mode=[])
+
+ # Build multi-output model
+ x = keras.Input((3,))
+ y1 = keras.layers.Dense(4, name='1')(x)
+ y2 = keras.layers.Dense(4, name='2')(x)
+ model = keras.models.Model(x, [y1, y2])
+ model.compile(optimizer='rmsprop', loss='mse')
+ x_np = np.random.random((10, 3))
+ y_np = np.random.random((10, 4))
+ w_np = np.random.random((10,))
+ # This will work
+ model.fit(x_np, [y_np, y_np], epochs=1,
+ sample_weight={'1': w_np})
+ # These will not
+ with self.assertRaises(ValueError):
+ model.fit(x_np, [y_np, y_np], epochs=1,
+ sample_weight=[w_np])
+ with self.assertRaises(TypeError):
+ model.fit(x_np, [y_np, y_np], epochs=1,
+ sample_weight=w_np)
+ with self.assertRaises(ValueError):
+ bad_w_np = np.random.random((11,))
+ model.fit(x_np, [y_np, y_np], epochs=1,
+ sample_weight={'1': bad_w_np})
+ with self.assertRaises(ValueError):
+ bad_w_np = np.random.random((10, 2))
+ model.fit(x_np, [y_np, y_np], epochs=1,
+ sample_weight={'1': bad_w_np})
+ with self.assertRaises(ValueError):
+ bad_w_np = np.random.random((10, 2, 2))
+ model.fit(x_np, [y_np, y_np], epochs=1,
+ sample_weight={'1': bad_w_np})
+
class LossMaskingTest(test.TestCase):
@@ -664,8 +771,8 @@ class TestDynamicTrainability(test.TestCase):
class TestGeneratorMethods(test.TestCase):
def test_generator_methods(self):
- arr_data = np.random.randint(0, 256, (50, 2))
- arr_labels = np.random.randint(0, 2, 50)
+ arr_data = np.random.random((50, 2))
+ arr_labels = np.random.random((50,))
def custom_generator():
batch_size = 10
@@ -678,49 +785,200 @@ class TestGeneratorMethods(test.TestCase):
y = arr_labels[start: end]
yield x, y
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(1, input_shape=(2,)))
- model.compile(loss='mse', optimizer='sgd')
-
- model.fit_generator(custom_generator(),
- steps_per_epoch=5,
- epochs=1,
- verbose=1,
- max_queue_size=10,
- workers=4,
- use_multiprocessing=True)
- model.fit_generator(custom_generator(),
- steps_per_epoch=5,
- epochs=1,
- verbose=1,
- max_queue_size=10,
- use_multiprocessing=False)
- model.fit_generator(custom_generator(),
- steps_per_epoch=5,
- epochs=1,
- verbose=1,
- max_queue_size=10,
- use_multiprocessing=False,
- validation_data=custom_generator(),
- validation_steps=10)
- model.predict_generator(custom_generator(),
- steps=5,
+ with self.test_session():
+ x = keras.Input((2,))
+ y = keras.layers.Dense(1)(x)
+ fn_model = keras.models.Model(x, y)
+ fn_model.compile(loss='mse', optimizer='sgd')
+
+ seq_model = keras.models.Sequential()
+ seq_model.add(keras.layers.Dense(1, input_shape=(2,)))
+ seq_model.compile(loss='mse', optimizer='sgd')
+
+ for model in [fn_model, seq_model]:
+ model.fit_generator(custom_generator(),
+ steps_per_epoch=5,
+ epochs=1,
+ verbose=1,
max_queue_size=10,
- workers=2,
+ workers=4,
use_multiprocessing=True)
- model.predict_generator(custom_generator(),
- steps=5,
+ model.fit_generator(custom_generator(),
+ steps_per_epoch=5,
+ epochs=1,
+ verbose=1,
+ max_queue_size=10,
+ use_multiprocessing=False)
+ model.fit_generator(custom_generator(),
+ steps_per_epoch=5,
+ epochs=1,
+ verbose=1,
+ max_queue_size=10,
+ use_multiprocessing=False,
+ validation_data=custom_generator(),
+ validation_steps=10)
+ model.predict_generator(custom_generator(),
+ steps=5,
+ max_queue_size=10,
+ workers=2,
+ use_multiprocessing=True)
+ model.predict_generator(custom_generator(),
+ steps=5,
+ max_queue_size=10,
+ use_multiprocessing=False)
+ model.evaluate_generator(custom_generator(),
+ steps=5,
+ max_queue_size=10,
+ workers=2,
+ use_multiprocessing=True)
+ model.evaluate_generator(custom_generator(),
+ steps=5,
+ max_queue_size=10,
+ use_multiprocessing=False)
+
+ # Test legacy API
+ model.fit_generator(custom_generator(),
+ steps_per_epoch=5,
+ epochs=1,
+ verbose=1,
+ max_q_size=10,
+ workers=4,
+ pickle_safe=True)
+ model.predict_generator(custom_generator(),
+ steps=5,
+ max_q_size=10,
+ workers=2,
+ pickle_safe=True)
+ model.evaluate_generator(custom_generator(),
+ steps=5,
+ max_q_size=10,
+ workers=2,
+ pickle_safe=True)
+
+ def test_generator_methods_with_sample_weights(self):
+ arr_data = np.random.random((50, 2))
+ arr_labels = np.random.random((50,))
+ arr_sample_weights = np.random.random((50,))
+
+ def custom_generator():
+ batch_size = 10
+ n_samples = 50
+ while True:
+ batch_index = np.random.randint(0, n_samples - batch_size)
+ start = batch_index
+ end = start + batch_size
+ x = arr_data[start: end]
+ y = arr_labels[start: end]
+ w = arr_sample_weights[start: end]
+ yield x, y, w
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(1, input_shape=(2,)))
+ model.compile(loss='mse', optimizer='sgd')
+
+ model.fit_generator(custom_generator(),
+ steps_per_epoch=5,
+ epochs=1,
+ verbose=1,
+ max_queue_size=10,
+ use_multiprocessing=False)
+ model.fit_generator(custom_generator(),
+ steps_per_epoch=5,
+ epochs=1,
+ verbose=1,
+ max_queue_size=10,
+ use_multiprocessing=False,
+ validation_data=custom_generator(),
+ validation_steps=10)
+ model.predict_generator(custom_generator(),
+ steps=5,
+ max_queue_size=10,
+ use_multiprocessing=False)
+ model.evaluate_generator(custom_generator(),
+ steps=5,
+ max_queue_size=10,
+ use_multiprocessing=False)
+
+ def test_generator_methods_invalid_use_case(self):
+
+ def custom_generator():
+ while 1:
+ yield 0
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(1, input_shape=(2,)))
+ model.compile(loss='mse', optimizer='sgd')
+
+ with self.assertRaises(ValueError):
+ model.fit_generator(custom_generator(),
+ steps_per_epoch=5,
+ epochs=1,
+ verbose=1,
max_queue_size=10,
use_multiprocessing=False)
- model.evaluate_generator(custom_generator(),
- steps=5,
- max_queue_size=10,
- workers=2,
- use_multiprocessing=True)
- model.evaluate_generator(custom_generator(),
- steps=5,
- max_queue_size=10,
- use_multiprocessing=False)
+ with self.assertRaises(ValueError):
+ model.fit_generator(custom_generator(),
+ steps_per_epoch=5,
+ epochs=1,
+ verbose=1,
+ max_queue_size=10,
+ use_multiprocessing=False,
+ validation_data=custom_generator(),
+ validation_steps=10)
+ with self.assertRaises(TypeError):
+ model.predict_generator(custom_generator(),
+ steps=5,
+ max_queue_size=10,
+ use_multiprocessing=False)
+ with self.assertRaises(ValueError):
+ model.evaluate_generator(custom_generator(),
+ steps=5,
+ max_queue_size=10,
+ use_multiprocessing=False)
+
+
+class TestTrainingUtils(test.TestCase):
+
+ def test_check_array_lengths(self):
+ keras.engine.training._check_array_lengths(None, None, None)
+ a_np = np.random.random((4, 3, 3))
+ keras.engine.training._check_array_lengths(a_np, a_np, a_np)
+ keras.engine.training._check_array_lengths(
+ [a_np, a_np], [a_np, a_np], [a_np, a_np])
+ keras.engine.training._check_array_lengths([None], [None], [None])
+
+ b_np = np.random.random((3, 4))
+ with self.assertRaises(ValueError):
+ keras.engine.training._check_array_lengths(a_np, None, None)
+ with self.assertRaises(ValueError):
+ keras.engine.training._check_array_lengths(a_np, a_np, None)
+ with self.assertRaises(ValueError):
+ keras.engine.training._check_array_lengths([a_np], [None], None)
+ with self.assertRaises(ValueError):
+ keras.engine.training._check_array_lengths([a_np], [b_np], None)
+ with self.assertRaises(ValueError):
+ keras.engine.training._check_array_lengths([a_np], None, [b_np])
+
+ def test_slice_arrays(self):
+ input_a = np.random.random((10, 3))
+ keras.engine.training._slice_arrays(None)
+ keras.engine.training._slice_arrays(input_a, 0)
+ keras.engine.training._slice_arrays(input_a, 0, 1)
+ keras.engine.training._slice_arrays(input_a, stop=2)
+ input_a = [None, [1, 1], None, [1, 1]]
+ keras.engine.training._slice_arrays(input_a, 0)
+ keras.engine.training._slice_arrays(input_a, 0, 1)
+ keras.engine.training._slice_arrays(input_a, stop=2)
+ input_a = [None]
+ keras.engine.training._slice_arrays(input_a, 0)
+ keras.engine.training._slice_arrays(input_a, 0, 1)
+ keras.engine.training._slice_arrays(input_a, stop=2)
+ input_a = None
+ keras.engine.training._slice_arrays(input_a, 0)
+ keras.engine.training._slice_arrays(input_a, 0, 1)
+ keras.engine.training._slice_arrays(input_a, stop=2)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/keras/python/keras/layers/convolutional_test.py b/tensorflow/contrib/keras/python/keras/layers/convolutional_test.py
index fbea2e12d7..00a7fbf8fb 100644
--- a/tensorflow/contrib/keras/python/keras/layers/convolutional_test.py
+++ b/tensorflow/contrib/keras/python/keras/layers/convolutional_test.py
@@ -484,6 +484,12 @@ class ZeroPaddingTest(test.TestCase):
np.testing.assert_allclose(np_output[:, 1:-2, :], 1.)
layer.get_config()
+ # test incorrect use
+ with self.assertRaises(ValueError):
+ keras.layers.ZeroPadding1D(padding=(1, 1, 1))
+ with self.assertRaises(ValueError):
+ keras.layers.ZeroPadding1D(padding=None)
+
def test_zero_padding_2d(self):
num_samples = 2
stack_size = 2
@@ -550,6 +556,12 @@ class ZeroPaddingTest(test.TestCase):
np.testing.assert_allclose(np_output[:, :, :, right_offset], 0.)
np.testing.assert_allclose(np_output[:, :, 1:-2, 3:-4], 1.)
+ # test incorrect use
+ with self.assertRaises(ValueError):
+ keras.layers.ZeroPadding2D(padding=(1, 1, 1))
+ with self.assertRaises(ValueError):
+ keras.layers.ZeroPadding2D(padding=None)
+
def test_zero_padding_3d(self):
num_samples = 2
stack_size = 2
@@ -579,6 +591,12 @@ class ZeroPaddingTest(test.TestCase):
np.testing.assert_allclose(np_output[:, :, :, offset, :], 0.)
np.testing.assert_allclose(np_output[:, 2:-2, 2:-2, 2:-2, :], 1.)
+ # test incorrect use
+ with self.assertRaises(ValueError):
+ keras.layers.ZeroPadding3D(padding=(1, 1))
+ with self.assertRaises(ValueError):
+ keras.layers.ZeroPadding3D(padding=None)
+
class UpSamplingTest(test.TestCase):
@@ -701,6 +719,12 @@ class CroppingTest(test.TestCase):
kwargs={'cropping': (2, 2)},
input_shape=inputs.shape)
+ # test incorrect use
+ with self.assertRaises(ValueError):
+ keras.layers.Cropping1D(cropping=(1, 1, 1))
+ with self.assertRaises(ValueError):
+ keras.layers.Cropping1D(cropping=None)
+
def test_cropping_2d(self):
num_samples = 2
stack_size = 2
@@ -756,48 +780,62 @@ class CroppingTest(test.TestCase):
# compare with input
np.testing.assert_allclose(np_output, inputs)
+ # test incorrect use
+ with self.assertRaises(ValueError):
+ keras.layers.Cropping2D(cropping=(1, 1, 1))
+ with self.assertRaises(ValueError):
+ keras.layers.Cropping2D(cropping=None)
+
def test_cropping_3d(self):
num_samples = 2
stack_size = 2
input_len_dim1 = 8
input_len_dim2 = 8
input_len_dim3 = 8
- cropping = ((2, 2), (1, 1), (2, 3))
+ croppings = [((2, 2), (1, 1), (2, 3)), 3, (0, 1, 1)]
- for data_format in ['channels_last', 'channels_first']:
- if data_format == 'channels_first':
- inputs = np.random.rand(num_samples, stack_size, input_len_dim1,
- input_len_dim2, input_len_dim3)
- else:
- inputs = np.random.rand(num_samples, input_len_dim1, input_len_dim2,
- input_len_dim3, stack_size)
- # basic test
- with self.test_session(use_gpu=True):
- testing_utils.layer_test(
- keras.layers.Cropping3D,
- kwargs={'cropping': cropping,
- 'data_format': data_format},
- input_shape=inputs.shape)
- # correctness test
- with self.test_session(use_gpu=True):
- layer = keras.layers.Cropping3D(
- cropping=cropping, data_format=data_format)
- layer.build(inputs.shape)
- output = layer(keras.backend.variable(inputs))
- np_output = keras.backend.eval(output)
- # compare with numpy
+ for cropping in croppings:
+ for data_format in ['channels_last', 'channels_first']:
if data_format == 'channels_first':
- expected_out = inputs[:, :,
- cropping[0][0]:-cropping[0][1],
- cropping[1][0]:-cropping[1][1],
- cropping[2][0]:-cropping[2][1]]
+ inputs = np.random.rand(num_samples, stack_size, input_len_dim1,
+ input_len_dim2, input_len_dim3)
else:
- expected_out = inputs[:,
- cropping[0][0]:-cropping[0][1],
- cropping[1][0]:-cropping[1][1],
- cropping[2][0]:-cropping[2][1], :]
- print(expected_out.shape)
- np.testing.assert_allclose(np_output, expected_out)
+ inputs = np.random.rand(num_samples, input_len_dim1, input_len_dim2,
+ input_len_dim3, stack_size)
+ # basic test
+ with self.test_session(use_gpu=True):
+ testing_utils.layer_test(
+ keras.layers.Cropping3D,
+ kwargs={'cropping': cropping,
+ 'data_format': data_format},
+ input_shape=inputs.shape)
+
+ if len(croppings) == 3 and len(croppings[0]) == 2:
+ # correctness test
+ with self.test_session(use_gpu=True):
+ layer = keras.layers.Cropping3D(
+ cropping=cropping, data_format=data_format)
+ layer.build(inputs.shape)
+ output = layer(keras.backend.variable(inputs))
+ np_output = keras.backend.eval(output)
+ # compare with numpy
+ if data_format == 'channels_first':
+ expected_out = inputs[:, :,
+ cropping[0][0]:-cropping[0][1],
+ cropping[1][0]:-cropping[1][1],
+ cropping[2][0]:-cropping[2][1]]
+ else:
+ expected_out = inputs[:,
+ cropping[0][0]:-cropping[0][1],
+ cropping[1][0]:-cropping[1][1],
+ cropping[2][0]:-cropping[2][1], :]
+ np.testing.assert_allclose(np_output, expected_out)
+
+ # test incorrect use
+ with self.assertRaises(ValueError):
+ keras.layers.Cropping3D(cropping=(1, 1))
+ with self.assertRaises(ValueError):
+ keras.layers.Cropping3D(cropping=None)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/keras/python/keras/layers/core.py b/tensorflow/contrib/keras/python/keras/layers/core.py
index af5e2fd46e..c3df1c85d7 100644
--- a/tensorflow/contrib/keras/python/keras/layers/core.py
+++ b/tensorflow/contrib/keras/python/keras/layers/core.py
@@ -207,12 +207,9 @@ class SpatialDropout2D(Dropout):
def _get_noise_shape(self, inputs):
input_shape = K.shape(inputs)
if self.data_format == 'channels_first':
- noise_shape = (input_shape[0], input_shape[1], 1, 1)
+ return (input_shape[0], input_shape[1], 1, 1)
elif self.data_format == 'channels_last':
- noise_shape = (input_shape[0], 1, 1, input_shape[3])
- else:
- raise ValueError('Invalid data_format:', self.data_format)
- return noise_shape
+ return (input_shape[0], 1, 1, input_shape[3])
class SpatialDropout3D(Dropout):
@@ -262,12 +259,9 @@ class SpatialDropout3D(Dropout):
def _get_noise_shape(self, inputs):
input_shape = K.shape(inputs)
if self.data_format == 'channels_first':
- noise_shape = (input_shape[0], input_shape[1], 1, 1, 1)
+ return (input_shape[0], input_shape[1], 1, 1, 1)
elif self.data_format == 'channels_last':
- noise_shape = (input_shape[0], 1, 1, 1, input_shape[4])
- else:
- raise ValueError('Invalid data_format:', self.data_format)
- return noise_shape
+ return (input_shape[0], 1, 1, 1, input_shape[4])
class Activation(Layer):
diff --git a/tensorflow/contrib/keras/python/keras/layers/core_test.py b/tensorflow/contrib/keras/python/keras/layers/core_test.py
index e0cc205591..818c55afe4 100644
--- a/tensorflow/contrib/keras/python/keras/layers/core_test.py
+++ b/tensorflow/contrib/keras/python/keras/layers/core_test.py
@@ -56,6 +56,24 @@ class CoreLayersTest(test.TestCase):
kwargs={'rate': 0.5},
input_shape=(2, 3, 4, 5))
+ with self.test_session():
+ testing_utils.layer_test(
+ keras.layers.SpatialDropout2D,
+ kwargs={'rate': 0.5, 'data_format': 'channels_first'},
+ input_shape=(2, 3, 4, 5))
+
+ with self.test_session():
+ testing_utils.layer_test(
+ keras.layers.SpatialDropout3D,
+ kwargs={'rate': 0.5},
+ input_shape=(2, 3, 4, 4, 5))
+
+ with self.test_session():
+ testing_utils.layer_test(
+ keras.layers.SpatialDropout3D,
+ kwargs={'rate': 0.5, 'data_format': 'channels_first'},
+ input_shape=(2, 3, 4, 4, 5))
+
def test_activation(self):
# with string argument
with self.test_session():
@@ -185,6 +203,7 @@ class CoreLayersTest(test.TestCase):
layer = keras.layers.ActivityRegularization(l1=0.1)
layer(keras.backend.variable(np.ones((2, 4))))
self.assertEqual(1, len(layer.losses))
+ _ = layer.get_config()
if __name__ == '__main__':
diff --git a/tensorflow/contrib/keras/python/keras/layers/merge.py b/tensorflow/contrib/keras/python/keras/layers/merge.py
index 84c03fdebd..64d0c40e61 100644
--- a/tensorflow/contrib/keras/python/keras/layers/merge.py
+++ b/tensorflow/contrib/keras/python/keras/layers/merge.py
@@ -172,7 +172,7 @@ class _Merge(Layer):
else:
return self._merge_function(inputs)
- def compute_output_shape(self, input_shape):
+ def _compute_output_shape(self, input_shape):
if input_shape[0] is None:
output_shape = None
else:
diff --git a/tensorflow/contrib/keras/python/keras/layers/merge_test.py b/tensorflow/contrib/keras/python/keras/layers/merge_test.py
index 2887fb851b..4a365c2c44 100644
--- a/tensorflow/contrib/keras/python/keras/layers/merge_test.py
+++ b/tensorflow/contrib/keras/python/keras/layers/merge_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.keras.python import keras
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -53,12 +54,20 @@ class MergeLayersTest(test.TestCase):
mask = layer.output_mask
self.assertListEqual(mask.get_shape().as_list(), [None, 4])
+ # test missing shape
+ i1 = array_ops.placeholder(shape=(4, None), dtype='float32')
+ i2 = array_ops.placeholder(shape=(4, 5), dtype='float32')
+ layer = keras.layers.Add()
+ o = layer([i1, i2])
+
def test_merge_elementwise_errors(self):
i1 = keras.layers.Input(shape=(4, 5))
i2 = keras.layers.Input(shape=(4, 6))
with self.assertRaises(ValueError):
keras.layers.add([i1, i2])
with self.assertRaises(ValueError):
+ keras.layers.add([i1])
+ with self.assertRaises(ValueError):
keras.layers.add(i1)
with self.assertRaises(ValueError):
keras.layers.add([i1])
@@ -121,6 +130,14 @@ class MergeLayersTest(test.TestCase):
self.assertEqual(out.shape, (2, 8, 5))
self.assertAllClose(out, np.concatenate([x1, x2], axis=1), atol=1e-4)
+ # test masking
+ m1 = keras.layers.Masking()(i1)
+ layer = keras.layers.Concatenate()
+ o = layer([m1, i2])
+ self.assertListEqual(o.get_shape().as_list(), [None, 4, 10])
+ mask = layer.output_mask
+ self.assertListEqual(mask.get_shape().as_list(), [None, 4])
+
def test_concatenate_errors(self):
i1 = keras.layers.Input(shape=(4, 5))
i2 = keras.layers.Input(shape=(3, 5))
@@ -138,6 +155,7 @@ class MergeLayersTest(test.TestCase):
o = keras.layers.dot([i1, i2], axes=1)
self.assertListEqual(o.get_shape().as_list(), [None, 1])
model = keras.models.Model([i1, i2], o)
+ _ = keras.layers.Dot(axes=1).get_config()
x1 = np.random.random((2, 4))
x2 = np.random.random((2, 4))
@@ -172,6 +190,9 @@ class MergeLayersTest(test.TestCase):
keras.layers.dot([i1], axes=-1)
with self.assertRaises(ValueError):
keras.layers.dot([i1, i2, i3], axes=-1)
+ with self.assertRaises(ValueError):
+ dot = keras.layers.Dot(1)
+ dot._compute_output_shape(1)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/keras/python/keras/layers/normalization_test.py b/tensorflow/contrib/keras/python/keras/layers/normalization_test.py
index 1a0686800e..eaeafb0c62 100644
--- a/tensorflow/contrib/keras/python/keras/layers/normalization_test.py
+++ b/tensorflow/contrib/keras/python/keras/layers/normalization_test.py
@@ -25,9 +25,9 @@ from tensorflow.contrib.keras.python.keras import testing_utils
from tensorflow.python.platform import test
-class NoiseLayersTest(test.TestCase):
+class NormalizationLayersTest(test.TestCase):
- def basic_batchnorm_test(self):
+ def test_basic_batchnorm(self):
with self.test_session():
testing_utils.layer_test(
keras.layers.BatchNormalization,
@@ -53,7 +53,7 @@ class NoiseLayersTest(test.TestCase):
'center': False},
input_shape=(3, 3))
- def batchnorm_weights_test(self):
+ def test_batchnorm_weights(self):
with self.test_session():
layer = keras.layers.BatchNormalization(scale=False, center=False)
layer.build((None, 3, 4))
@@ -65,16 +65,18 @@ class NoiseLayersTest(test.TestCase):
self.assertEqual(len(layer.trainable_weights), 2)
self.assertEqual(len(layer.weights), 4)
- def batchnorm_regularization_test(self):
+ def test_batchnorm_regularization(self):
with self.test_session():
layer = keras.layers.BatchNormalization(
gamma_regularizer='l1', beta_regularizer='l1')
layer.build((None, 3, 4))
self.assertEqual(len(layer.losses), 2)
+ max_norm = keras.constraints.max_norm
layer = keras.layers.BatchNormalization(
- gamma_constraint='l1', beta_constraint='l1')
+ gamma_constraint=max_norm, beta_constraint=max_norm)
layer.build((None, 3, 4))
- self.assertEqual(len(layer.constraints), 2)
+ self.assertEqual(layer.gamma.constraint, max_norm)
+ self.assertEqual(layer.beta.constraint, max_norm)
def test_batchnorm_correctness(self):
with self.test_session():
diff --git a/tensorflow/contrib/keras/python/keras/layers/wrappers.py b/tensorflow/contrib/keras/python/keras/layers/wrappers.py
index 91614c288d..aee02f432e 100644
--- a/tensorflow/contrib/keras/python/keras/layers/wrappers.py
+++ b/tensorflow/contrib/keras/python/keras/layers/wrappers.py
@@ -204,7 +204,6 @@ class TimeDistributed(Wrapper):
step,
inputs,
initial_states=[],
- input_length=input_shape[1],
unroll=False)
y = outputs
else:
@@ -292,7 +291,7 @@ class Bidirectional(Wrapper):
self.backward_layer.set_weights(weights[nw // 2:])
def _compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
+ input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
if self.merge_mode in ['sum', 'ave', 'mul']:
return self.forward_layer._compute_output_shape(input_shape) # pylint: disable=protected-access
elif self.merge_mode == 'concat':
diff --git a/tensorflow/contrib/keras/python/keras/layers/wrappers_test.py b/tensorflow/contrib/keras/python/keras/layers/wrappers_test.py
index d4cd1ccbb4..531fa76dd8 100644
--- a/tensorflow/contrib/keras/python/keras/layers/wrappers_test.py
+++ b/tensorflow/contrib/keras/python/keras/layers/wrappers_test.py
@@ -33,7 +33,6 @@ class TimeDistributedTest(test.TestCase):
model.add(
keras.layers.TimeDistributed(
keras.layers.Dense(2), input_shape=(3, 4)))
- model.add(keras.layers.Activation('relu'))
model.compile(optimizer='rmsprop', loss='mse')
model.fit(
np.random.random((10, 3, 4)),
@@ -44,6 +43,19 @@ class TimeDistributedTest(test.TestCase):
# test config
model.get_config()
+ def test_timedistributed_static_batch_size(self):
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.TimeDistributed(
+ keras.layers.Dense(2), input_shape=(3, 4), batch_size=10))
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.fit(
+ np.random.random((10, 3, 4)),
+ np.random.random((10, 3, 2)),
+ epochs=1,
+ batch_size=10)
+
def test_timedistributed_conv2d(self):
# test with Conv2D
with self.test_session():
@@ -77,31 +89,6 @@ class TimeDistributedTest(test.TestCase):
epochs=1,
batch_size=10)
- def test_timedistributed_sequential(self):
- # test wrapping Sequential model
- with self.test_session():
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(3, input_dim=2))
- outer_model = keras.models.Sequential()
- outer_model.add(keras.layers.TimeDistributed(model, input_shape=(3, 2)))
- outer_model.compile(optimizer='rmsprop', loss='mse')
- outer_model.fit(
- np.random.random((10, 3, 2)),
- np.random.random((10, 3, 3)),
- epochs=1,
- batch_size=10)
-
- # test with functional API
- x = keras.layers.Input(shape=(3, 2))
- y = keras.layers.TimeDistributed(model)(x)
- outer_model = keras.models.Model(x, y)
- outer_model.compile(optimizer='rmsprop', loss='mse')
- outer_model.fit(
- np.random.random((10, 3, 2)),
- np.random.random((10, 3, 3)),
- epochs=1,
- batch_size=10)
-
def test_regularizers(self):
with self.test_session():
model = keras.models.Sequential()
@@ -133,7 +120,7 @@ class BidirectionalTest(test.TestCase):
timesteps = 2
output_dim = 2
with self.test_session():
- for mode in ['sum', 'concat']:
+ for mode in ['sum', 'concat', 'ave', 'mul']:
x = np.random.random((samples, timesteps, dim))
target_dim = 2 * output_dim if mode == 'concat' else output_dim
y = np.random.random((samples, target_dim))
@@ -146,11 +133,35 @@ class BidirectionalTest(test.TestCase):
model.compile(loss='mse', optimizer='sgd')
model.fit(x, y, epochs=1, batch_size=1)
+ # test compute output shape
+ ref_shape = model.layers[-1].output.get_shape()
+ shape = model.layers[-1]._compute_output_shape(
+ (None, timesteps, dim))
+ self.assertListEqual(shape.as_list(), ref_shape.as_list())
+
# test config
model.get_config()
model = keras.models.model_from_json(model.to_json())
model.summary()
+ def test_bidirectional_weight_loading(self):
+ rnn = keras.layers.SimpleRNN
+ samples = 2
+ dim = 2
+ timesteps = 2
+ output_dim = 2
+ with self.test_session():
+ x = np.random.random((samples, timesteps, dim))
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.Bidirectional(
+ rnn(output_dim), input_shape=(timesteps, dim)))
+ y_ref = model.predict(x)
+ weights = model.layers[-1].get_weights()
+ model.layers[-1].set_weights(weights)
+ y = model.predict(x)
+ self.assertAllClose(y, y_ref)
+
def test_bidirectional_stacked(self):
# test stacked bidirectional layers
rnn = keras.layers.SimpleRNN
diff --git a/tensorflow/contrib/keras/python/keras/models.py b/tensorflow/contrib/keras/python/keras/models.py
index c01d2c45e0..1a0d95c7ff 100644
--- a/tensorflow/contrib/keras/python/keras/models.py
+++ b/tensorflow/contrib/keras/python/keras/models.py
@@ -167,18 +167,8 @@ def save_model(model, filepath, overwrite=True, include_optimizer=True):
optimizer_weights_group = f.create_group('optimizer_weights')
weight_values = K.batch_get_value(symbolic_weights)
weight_names = []
- for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)):
- # Default values of symbolic_weights is /variable for theano
- if K.backend() == 'theano':
- if hasattr(w, 'name') and w.name != '/variable':
- name = str(w.name)
- else:
- name = 'param_' + str(i)
- else:
- if hasattr(w, 'name') and w.name:
- name = str(w.name)
- else:
- name = 'param_' + str(i)
+ for w, val in zip(symbolic_weights, weight_values):
+ name = str(w.name)
weight_names.append(name.encode('utf8'))
optimizer_weights_group.attrs['weight_names'] = weight_names
for name, val in zip(weight_names, weight_values):
@@ -664,12 +654,6 @@ class Sequential(Model):
self.build()
return self.model.regularizers
- @property
- def constraints(self):
- if self.model is None:
- self.build()
- return self.model.constraints
-
def get_weights(self):
"""Retrieves the weights of the model.
diff --git a/tensorflow/contrib/keras/python/keras/models_test.py b/tensorflow/contrib/keras/python/keras/models_test.py
index f7246097ee..44088a1b32 100644
--- a/tensorflow/contrib/keras/python/keras/models_test.py
+++ b/tensorflow/contrib/keras/python/keras/models_test.py
@@ -19,12 +19,14 @@ from __future__ import division
from __future__ import print_function
import os
+import shutil
import tempfile
import numpy as np
from tensorflow.contrib.keras.python import keras
from tensorflow.python.platform import test
+from tensorflow.python.training import training as training_module
try:
import h5py # pylint:disable=g-import-not-at-top
@@ -147,6 +149,23 @@ class TestModelSaving(test.TestCase):
model = keras.models.load_model(fname)
os.remove(fname)
+ def test_saving_with_tf_optimizer(self):
+ if h5py is None:
+ return # Skip test if models cannot be saved.
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.Dense(3))
+ model.compile(loss='mse',
+ optimizer=training_module.AdadeltaOptimizer(0.1),
+ metrics=['acc'])
+
+ _, fname = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, fname)
+ model = keras.models.load_model(fname)
+ os.remove(fname)
+
def test_saving_right_after_compilation(self):
if h5py is None:
return # Skip test if models cannot be saved.
@@ -189,6 +208,16 @@ class TestSequential(test.TestCase):
"""Most Sequential model API tests are covered in `training_test.py`.
"""
+ def test_basic_methods(self):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(1, input_dim=2))
+ model.add(keras.layers.Dropout(0.3, name='dp'))
+ model.add(keras.layers.Dense(2, kernel_regularizer='l2',
+ kernel_constraint='max_norm'))
+ model.build()
+ self.assertEqual(model.state_updates, model.model.state_updates)
+ self.assertEqual(model.get_layer(name='dp').name, 'dp')
+
def test_sequential_pop(self):
num_hidden = 5
input_dim = 3
@@ -209,6 +238,83 @@ class TestSequential(test.TestCase):
y = np.random.random((batch_size, num_hidden))
model.fit(x, y, epochs=1)
+ # Test popping single-layer model
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
+ model.pop()
+ self.assertEqual(len(model.layers), 0)
+ self.assertEqual(len(model.outputs), 0)
+
+ # Invalid use case
+ model = keras.models.Sequential()
+ with self.assertRaises(TypeError):
+ model.pop()
+
+ def test_sequential_weight_loading(self):
+ if h5py is None:
+ return
+
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+ h5_path = os.path.join(temp_dir, 'test.h5')
+
+ num_hidden = 5
+ input_dim = 3
+ batch_size = 5
+ num_classes = 2
+
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
+ model.add(keras.layers.Dense(num_classes))
+
+ x = np.random.random((batch_size, input_dim))
+ ref_y = model.predict(x)
+
+ model.save_weights(h5_path)
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(num_hidden, input_dim=input_dim))
+ model.add(keras.layers.Dense(num_classes))
+ model.load_weights(h5_path)
+ y = model.predict(x)
+
+ self.assertAllClose(y, ref_y)
+
+ def test_invalid_use_cases(self):
+ with self.test_session():
+ # Added objects must be layer instances
+ with self.assertRaises(TypeError):
+ model = keras.models.Sequential()
+ model.add(None)
+
+ # Added layers must have an inputs shape
+ with self.assertRaises(ValueError):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(1))
+
+ # Added layers cannot have multiple outputs
+ class MyLayer(keras.layers.Layer):
+
+ def call(self, inputs):
+ return [3 * inputs, 2 * inputs]
+
+ def _compute_output_shape(self, input_shape):
+ return [input_shape, input_shape]
+
+ with self.assertRaises(ValueError):
+ model = keras.models.Sequential()
+ model.add(MyLayer(input_shape=(3,)))
+ with self.assertRaises(TypeError):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(1, input_dim=1))
+ model.add(MyLayer())
+
+ # Building empty model
+ model = keras.models.Sequential()
+ with self.assertRaises(TypeError):
+ model.build()
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/keras/python/keras/preprocessing/image_test.py b/tensorflow/contrib/keras/python/keras/preprocessing/image_test.py
index 94768f5258..bb09ed1ae8 100644
--- a/tensorflow/contrib/keras/python/keras/preprocessing/image_test.py
+++ b/tensorflow/contrib/keras/python/keras/preprocessing/image_test.py
@@ -104,6 +104,16 @@ class TestImage(test.TestCase):
x = np.random.random((32, 10, 10))
generator.flow(np.arange(x.shape[0]))
+ with self.assertRaises(ValueError):
+ generator = keras.preprocessing.image.ImageDataGenerator(
+ data_format='unknown')
+
+ generator = keras.preprocessing.image.ImageDataGenerator(
+ zoom_range=(2, 2))
+ with self.assertRaises(ValueError):
+ generator = keras.preprocessing.image.ImageDataGenerator(
+ zoom_range=(2, 2, 2))
+
def test_image_data_generator_fit(self):
generator = keras.preprocessing.image.ImageDataGenerator(
featurewise_center=True,
@@ -169,6 +179,12 @@ class TestImage(test.TestCase):
im.save(os.path.join(temp_dir, filename))
count += 1
+ # Test image loading util
+ fname = os.path.join(temp_dir, filenames[0])
+ _ = keras.preprocessing.image.load_img(fname)
+ _ = keras.preprocessing.image.load_img(fname, grayscale=True)
+ _ = keras.preprocessing.image.load_img(fname, target_size=(10, 10))
+
# create iterator
generator = keras.preprocessing.image.ImageDataGenerator()
dir_iterator = generator.flow_from_directory(temp_dir)
@@ -177,6 +193,7 @@ class TestImage(test.TestCase):
self.assertEqual(len(dir_iterator.class_indices), num_classes)
self.assertEqual(len(dir_iterator.classes), count)
self.assertEqual(sorted(dir_iterator.filenames), sorted(filenames))
+ _ = dir_iterator.next()
def test_img_utils(self):
if PIL is None:
@@ -214,6 +231,16 @@ class TestImage(test.TestCase):
x = keras.preprocessing.image.img_to_array(img, data_format='channels_last')
self.assertEqual(x.shape, (height, width, 1))
+ def test_img_transforms(self):
+ x = np.random.random((3, 200, 200))
+ _ = keras.preprocessing.image.random_rotation(x, 20)
+ _ = keras.preprocessing.image.random_shift(x, 0.2, 0.2)
+ _ = keras.preprocessing.image.random_shear(x, 2.)
+ _ = keras.preprocessing.image.random_zoom(x, (0.5, 0.5))
+ with self.assertRaises(ValueError):
+ keras.preprocessing.image.random_zoom(x, (0, 0, 0))
+ _ = keras.preprocessing.image.random_channel_shift(x, 2.)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/keras/python/keras/utils/conv_utils.py b/tensorflow/contrib/keras/python/keras/utils/conv_utils.py
index 570a63b606..ea3a70edab 100644
--- a/tensorflow/contrib/keras/python/keras/utils/conv_utils.py
+++ b/tensorflow/contrib/keras/python/keras/utils/conv_utils.py
@@ -21,46 +21,12 @@ from __future__ import print_function
import numpy as np
from six.moves import range # pylint: disable=redefined-builtin
+# pylint: disable=unused-import
from tensorflow.contrib.keras.python.keras import backend as K
-
-
-def normalize_tuple(value, n, name):
- """Transforms a single int or iterable of ints into an int tuple.
-
- Arguments:
- value: The value to validate and convert. Could an int, or any iterable
- of ints.
- n: The size of the tuple to be returned.
- name: The name of the argument being validated, e.g. "strides" or
- "kernel_size". This is only used to format error messages.
-
- Returns:
- A tuple of n integers.
-
- Raises:
- ValueError: If something else than an int/long or iterable thereof was
- passed.
- """
- if isinstance(value, int):
- return (value,) * n
- else:
- try:
- value_tuple = tuple(value)
- except TypeError:
- raise ValueError('The `' + name + '` argument must be a tuple of ' +
- str(n) + ' integers. Received: ' + str(value))
- if len(value_tuple) != n:
- raise ValueError('The `' + name + '` argument must be a tuple of ' +
- str(n) + ' integers. Received: ' + str(value))
- for single_value in value_tuple:
- try:
- int(single_value)
- except ValueError:
- raise ValueError('The `' + name + '` argument must be a tuple of ' +
- str(n) + ' integers. Received: ' + str(value) + ' '
- 'including element ' + str(single_value) + ' of type' +
- ' ' + str(type(single_value)))
- return value_tuple
+from tensorflow.python.layers.utils import conv_input_length
+from tensorflow.python.layers.utils import conv_output_length
+from tensorflow.python.layers.utils import deconv_output_length as deconv_length
+from tensorflow.python.layers.utils import normalize_tuple
def normalize_data_format(value):
@@ -104,66 +70,3 @@ def convert_kernel(kernel):
no_flip = (slice(None, None), slice(None, None))
slices[-2:] = no_flip
return np.copy(kernel[slices])
-
-
-def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
- """Determines output length of a convolution given input length.
-
- Arguments:
- input_length: integer.
- filter_size: integer.
- padding: one of "same", "valid", "full".
- stride: integer.
- dilation: dilation rate, integer.
-
- Returns:
- The output length (integer).
- """
- if input_length is None:
- return None
- assert padding in {'same', 'valid', 'full', 'causal'}
- dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
- if padding == 'same':
- output_length = input_length
- elif padding == 'valid':
- output_length = input_length - dilated_filter_size + 1
- elif padding == 'full':
- output_length = input_length + dilated_filter_size - 1
- elif padding == 'causal':
- output_length = input_length
- return (output_length + stride - 1) // stride
-
-
-def conv_input_length(output_length, filter_size, padding, stride):
- """Determines input length of a convolution given output length.
-
- Arguments:
- output_length: integer.
- filter_size: integer.
- padding: one of "same", "valid", "full".
- stride: integer.
-
- Returns:
- The input length (integer).
- """
- if output_length is None:
- return None
- assert padding in {'same', 'valid', 'full'}
- if padding == 'same':
- pad = filter_size // 2
- elif padding == 'valid':
- pad = 0
- elif padding == 'full':
- pad = filter_size - 1
- return (output_length - 1) * stride - 2 * pad + filter_size
-
-
-def deconv_length(dim_size, stride_size, kernel_size, padding):
- if dim_size is None:
- return None
- dim_size *= stride_size
- if padding == 'valid':
- dim_size += max(kernel_size - stride_size, 0)
- elif padding == 'full':
- dim_size -= (stride_size + kernel_size - 2)
- return dim_size
diff --git a/tensorflow/contrib/keras/python/keras/utils/data_utils_test.py b/tensorflow/contrib/keras/python/keras/utils/data_utils_test.py
index 7b73775f46..55d08a34d0 100644
--- a/tensorflow/contrib/keras/python/keras/utils/data_utils_test.py
+++ b/tensorflow/contrib/keras/python/keras/utils/data_utils_test.py
@@ -19,14 +19,75 @@ from __future__ import division
from __future__ import print_function
from itertools import cycle
+import os
+import tarfile
import threading
+import zipfile
import numpy as np
+from six.moves.urllib.parse import urljoin
+from six.moves.urllib.request import pathname2url
from tensorflow.contrib.keras.python import keras
from tensorflow.python.platform import test
+class TestGetFileAndValidateIt(test.TestCase):
+
+ def test_get_file_and_validate_it(self):
+ """Tests get_file from a url, plus extraction and validation.
+ """
+ dest_dir = self.get_temp_dir()
+ orig_dir = self.get_temp_dir()
+
+ text_file_path = os.path.join(orig_dir, 'test.txt')
+ zip_file_path = os.path.join(orig_dir, 'test.zip')
+ tar_file_path = os.path.join(orig_dir, 'test.tar.gz')
+
+ with open(text_file_path, 'w') as text_file:
+ text_file.write('Float like a butterfly, sting like a bee.')
+
+ with tarfile.open(tar_file_path, 'w:gz') as tar_file:
+ tar_file.add(text_file_path)
+
+ with zipfile.ZipFile(zip_file_path, 'w') as zip_file:
+ zip_file.write(text_file_path)
+
+ origin = urljoin('file://', pathname2url(os.path.abspath(tar_file_path)))
+
+ path = keras.utils.data_utils.get_file('test.txt', origin,
+ untar=True, cache_subdir=dest_dir)
+ filepath = path + '.tar.gz'
+ hashval_sha256 = keras.utils.data_utils._hash_file(filepath)
+ hashval_md5 = keras.utils.data_utils._hash_file(filepath, algorithm='md5')
+ path = keras.utils.data_utils.get_file(
+ 'test.txt', origin, md5_hash=hashval_md5,
+ untar=True, cache_subdir=dest_dir)
+ path = keras.utils.data_utils.get_file(
+ filepath, origin, file_hash=hashval_sha256,
+ extract=True, cache_subdir=dest_dir)
+ self.assertTrue(os.path.exists(filepath))
+ self.assertTrue(keras.utils.data_utils.validate_file(filepath,
+ hashval_sha256))
+ self.assertTrue(keras.utils.data_utils.validate_file(filepath, hashval_md5))
+ os.remove(filepath)
+
+ origin = urljoin('file://', pathname2url(os.path.abspath(zip_file_path)))
+
+ hashval_sha256 = keras.utils.data_utils._hash_file(zip_file_path)
+ hashval_md5 = keras.utils.data_utils._hash_file(zip_file_path,
+ algorithm='md5')
+ path = keras.utils.data_utils.get_file(
+ 'test', origin, md5_hash=hashval_md5,
+ extract=True, cache_subdir=dest_dir)
+ path = keras.utils.data_utils.get_file(
+ 'test', origin, file_hash=hashval_sha256,
+ extract=True, cache_subdir=dest_dir)
+ self.assertTrue(os.path.exists(path))
+ self.assertTrue(keras.utils.data_utils.validate_file(path, hashval_sha256))
+ self.assertTrue(keras.utils.data_utils.validate_file(path, hashval_md5))
+
+
class ThreadsafeIter(object):
def __init__(self, it):
diff --git a/tensorflow/contrib/keras/python/keras/utils/io_utils_test.py b/tensorflow/contrib/keras/python/keras/utils/io_utils_test.py
new file mode 100644
index 0000000000..baa9781e71
--- /dev/null
+++ b/tensorflow/contrib/keras/python/keras/utils/io_utils_test.py
@@ -0,0 +1,100 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for io_utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+
+import numpy as np
+
+from tensorflow.contrib.keras.python import keras
+from tensorflow.python.platform import test
+
+try:
+ import h5py # pylint:disable=g-import-not-at-top
+except ImportError:
+ h5py = None
+
+
+def create_dataset(h5_path='test.h5'):
+ x = np.random.randn(200, 10).astype('float32')
+ y = np.random.randint(0, 2, size=(200, 1))
+ f = h5py.File(h5_path, 'w')
+ # Creating dataset to store features
+ x_dset = f.create_dataset('my_data', (200, 10), dtype='f')
+ x_dset[:] = x
+ # Creating dataset to store labels
+ y_dset = f.create_dataset('my_labels', (200, 1), dtype='i')
+ y_dset[:] = y
+ f.close()
+
+
+class TestIOUtils(test.TestCase):
+
+ def test_HDF5Matrix(self):
+ if h5py is None:
+ return
+
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ h5_path = os.path.join(temp_dir, 'test.h5')
+ create_dataset(h5_path)
+
+ # Instantiating HDF5Matrix for the training set,
+ # which is a slice of the first 150 elements
+ x_train = keras.utils.io_utils.HDF5Matrix(
+ h5_path, 'my_data', start=0, end=150)
+ y_train = keras.utils.io_utils.HDF5Matrix(
+ h5_path, 'my_labels', start=0, end=150)
+
+ # Likewise for the test set
+ x_test = keras.utils.io_utils.HDF5Matrix(
+ h5_path, 'my_data', start=150, end=200)
+ y_test = keras.utils.io_utils.HDF5Matrix(
+ h5_path, 'my_labels', start=150, end=200)
+
+ # HDF5Matrix behave more or less like Numpy matrices
+ # with regard to indexing
+ self.assertEqual(y_train.shape, (150, 1))
+ # But they do not support negative indices, so don't try print(x_train[-1])
+
+ self.assertEqual(y_train.dtype, np.dtype('i'))
+ self.assertEqual(y_train.ndim, 2)
+ self.assertEqual(y_train.size, 150)
+
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(64, input_shape=(10,), activation='relu'))
+ model.add(keras.layers.Dense(1, activation='sigmoid'))
+ model.compile(loss='binary_crossentropy', optimizer='sgd')
+
+ # Note: you have to use shuffle='batch' or False with HDF5Matrix
+ model.fit(x_train, y_train, batch_size=32, shuffle='batch', verbose=False)
+ # test that evalutation and prediction
+ # don't crash and return reasonable results
+ out_pred = model.predict(x_test, batch_size=32, verbose=False)
+ out_eval = model.evaluate(x_test, y_test, batch_size=32, verbose=False)
+
+ self.assertEqual(out_pred.shape, (50, 1))
+ self.assertEqual(out_eval.shape, ())
+ self.assertGreater(out_eval, 0)
+
+
+if __name__ == '__main__':
+ test.main()