aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/keras/python/keras/testing_utils.py
blob: bf6f661adff4a22626763f207ef91839766da774 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for unit-testing Keras."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.keras.python import keras
from tensorflow.python.util import tf_inspect


def get_test_data(train_samples,
                  test_samples,
                  input_shape,
                  num_classes):
  """Builds a random, class-structured dataset split into train and test.

  One random template of shape `input_shape` is drawn per class. Every sample
  is the template of its randomly drawn class plus Gaussian noise with zero
  mean and unit standard deviation.

  Arguments:
    train_samples: Integer, number of training samples to generate.
    test_samples: Integer, number of test samples to generate.
    input_shape: Tuple of integers, shape of a single input sample.
    num_classes: Integer, number of distinct classes. Targets are integer
      labels drawn from `[0, num_classes)`.

  Returns:
    A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
  """
  total = train_samples + test_samples
  # One prototype per class, with entries in [0, 2 * num_classes).
  class_templates = 2 * num_classes * np.random.random(
      (num_classes,) + input_shape)
  labels = np.random.randint(0, num_classes, size=(total,))
  features = np.zeros((total,) + input_shape)
  # Noise is drawn one sample at a time so the sequence of random draws (and
  # hence the output under a fixed seed) stays stable.
  for idx, label in enumerate(labels):
    noise = np.random.normal(loc=0, scale=1., size=input_shape)
    features[idx] = class_templates[label] + noise
  train_split = (features[:train_samples], labels[:train_samples])
  test_split = (features[train_samples:], labels[train_samples:])
  return train_split, test_split


def _check_model(model, model_cls, input_data, expected_output_shape,
                 expected_output):
  """Runs the checks shared by the functional and Sequential code paths.

  Predicts on `input_data`, compares the prediction against the expected
  shape (and values, when given), round-trips the model through its config,
  and performs one training step.

  Arguments:
    model: Keras model wrapping the layer under test.
    model_cls: Model class whose `from_config` is used to rebuild `model`
      from its config.
    input_data: Numpy array fed to the model.
    expected_output_shape: Tuple produced by the layer's shape inference;
      `None` entries are not checked.
    expected_output: Optional Numpy array of expected output values.

  Returns:
    The Numpy array predicted by `model` for `input_data`.
  """
  actual_output = model.predict(input_data)
  for expected_dim, actual_dim in zip(expected_output_shape,
                                      actual_output.shape):
    if expected_dim is not None:
      assert expected_dim == actual_dim
  if expected_output is not None:
    np.testing.assert_allclose(actual_output, expected_output, rtol=1e-3)

  # test serialization, weight setting at model level
  model_config = model.get_config()
  recovered_model = model_cls.from_config(model_config)
  if model.weights:
    weights = model.get_weights()
    recovered_model.set_weights(weights)
    output = recovered_model.predict(input_data)
    np.testing.assert_allclose(output, actual_output, rtol=1e-3)

  # test training mode (e.g. useful for dropout tests)
  model.compile('rmsprop', 'mse')
  model.train_on_batch(input_data, actual_output)
  return actual_output


def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
               input_data=None, expected_output=None,
               expected_output_dtype=None):
  """Test routine for a layer with a single input and single output.

  The layer is instantiated, called in a functional-API model and then used
  as the first layer of a Sequential model. In both cases the output shape is
  compared against the layer's shape inference, the model is round-tripped
  through its config, and one training step is run.

  Arguments:
    layer_cls: Layer class object.
    kwargs: Optional dictionary of keyword arguments for instantiating the
      layer. The dictionary is copied and never modified.
    input_shape: Input shape tuple, including the batch dimension. Required
      when `input_data` is not given; `None` entries are then replaced by a
      random size in `[1, 3]`. Defaults to `input_data.shape` otherwise.
    input_dtype: Data type of the input data. Defaults to `'float32'` when
      data is generated, and to `input_data.dtype` otherwise.
    input_data: Numpy array of input data. When omitted, random data is
      generated: uniform in `[0, 10)`, shifted by -0.5 for float dtypes.
    expected_output: Optional Numpy array of expected output values; the
      layer output is compared against it with `rtol=1e-3`.
    expected_output_dtype: Data type expected for the output. Defaults to
      `input_dtype`.

  Returns:
    The output data (Numpy array) returned by the layer, for additional
    checks to be done by the calling code.
  """
  if input_data is None:
    assert input_shape
    if not input_dtype:
      input_dtype = 'float32'
    input_data_shape = list(input_shape)
    for i, e in enumerate(input_data_shape):
      if e is None:
        input_data_shape[i] = np.random.randint(1, 4)
    input_data = 10 * np.random.random(input_data_shape)
    # Compare the first five characters: a four-character slice ('floa') can
    # never equal 'float', which silently disabled this shift.
    if input_dtype[:5] == 'float':
      input_data -= 0.5
    input_data = input_data.astype(input_dtype)
  elif input_shape is None:
    input_shape = input_data.shape
  if input_dtype is None:
    input_dtype = input_data.dtype
  if expected_output_dtype is None:
    expected_output_dtype = input_dtype

  # instantiation (work on a copy so the caller's dict is never mutated)
  kwargs = dict(kwargs or {})
  layer = layer_cls(**kwargs)

  # test get_weights , set_weights at layer level
  weights = layer.get_weights()
  layer.set_weights(weights)

  # test and instantiation from weights
  # NOTE(review): this tests membership on the object returned by
  # `getargspec` itself, not on its argument-name list; presumably `.args`
  # was intended, in which case this branch is effectively dead. Confirm
  # before changing: enabling it would pass the (possibly empty) weight list
  # captured above to the layer constructor.
  if 'weights' in tf_inspect.getargspec(layer_cls.__init__):
    kwargs['weights'] = weights
    layer = layer_cls(**kwargs)

  # test in functional API
  x = keras.layers.Input(shape=input_shape[1:], dtype=input_dtype)
  y = layer(x)
  assert keras.backend.dtype(y) == expected_output_dtype

  # check shape inference
  model = keras.models.Model(x, y)
  expected_output_shape = tuple(
      layer._compute_output_shape(input_shape).as_list())  # pylint: disable=protected-access
  _check_model(model, keras.models.Model, input_data, expected_output_shape,
               expected_output)

  # test as first layer in Sequential API
  layer_config = layer.get_config()
  layer_config['batch_input_shape'] = input_shape
  layer = layer.__class__.from_config(layer_config)

  model = keras.models.Sequential()
  model.add(layer)
  actual_output = _check_model(model, keras.models.Sequential, input_data,
                               expected_output_shape, expected_output)

  # for further checks in the caller function
  return actual_output