aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/engine/training_gpu_test.py
blob: 5825ce814fd84bf59637f6079e7402d752e2b77b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for training routines."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python import keras
from tensorflow.python.framework import test_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.layers.convolutional import Conv2D
from tensorflow.python.platform import test
from tensorflow.python.training import rmsprop


class TrainingGPUTest(test.TestCase):
  """Keras training tests that must run on a CUDA GPU."""

  @test_util.run_in_graph_and_eager_modes
  def test_model_with_crossentropy_losses_channels_first(self):
    """Tests use of all crossentropy losses with `channels_first`.

    Tests `sparse_categorical_crossentropy`, `categorical_crossentropy`,
    and `binary_crossentropy`.
    Verifies that evaluate gives the same result with either `channels_first`
    or `channels_last` image_data_format.

    The test is skipped when no CUDA GPU is available.
    """
    def prepare_simple_model(input_tensor, loss_name, target):
      """Builds and compiles a single 1x1 `Conv2D` model for `loss_name`.

      Args:
        input_tensor: Keras input tensor for the model.
        loss_name: One of 'sparse_categorical_crossentropy',
          'categorical_crossentropy' or 'binary_crossentropy'.
        target: Label array laid out in the current image data format; used
          to derive the number of output channels.

      Returns:
        A compiled `keras.models.Model`.

      Raises:
        ValueError: If `loss_name` is not one of the supported losses.
      """
      # The channel axis follows the globally configured data format, so it
      # is read here, at model-construction time.
      axis = 1 if K.image_data_format() == 'channels_first' else -1
      if loss_name == 'sparse_categorical_crossentropy':
        loss = lambda y_true, y_pred: K.sparse_categorical_crossentropy(  # pylint: disable=g-long-lambda
            y_true, y_pred, axis=axis)
        # Sparse labels are class indices stored as float32; cast so Conv2D
        # receives an integer filter count.
        num_channels = int(np.amax(target)) + 1
        activation = 'softmax'
      elif loss_name == 'categorical_crossentropy':
        loss = lambda y_true, y_pred: K.categorical_crossentropy(  # pylint: disable=g-long-lambda
            y_true, y_pred, axis=axis)
        num_channels = target.shape[axis]
        activation = 'softmax'
      elif loss_name == 'binary_crossentropy':
        loss = lambda y_true, y_pred: K.binary_crossentropy(y_true, y_pred)  # pylint: disable=unnecessary-lambda
        num_channels = target.shape[axis]
        activation = 'sigmoid'
      else:
        raise ValueError('Unsupported loss: {}'.format(loss_name))
      # Deterministic initializers make both data formats compute the same
      # function, so their losses are directly comparable.
      predictions = Conv2D(num_channels,
                           1,
                           activation=activation,
                           kernel_initializer='ones',
                           bias_initializer='ones')(input_tensor)
      simple_model = keras.models.Model(inputs=input_tensor,
                                        outputs=predictions)
      simple_model.compile(optimizer=rmsprop.RMSPropOptimizer(1e-3), loss=loss)
      return simple_model

    if not test.is_gpu_available(cuda_only=True):
      self.skipTest('Test requires a CUDA GPU.')

    losses_to_test = ['sparse_categorical_crossentropy',
                      'categorical_crossentropy', 'binary_crossentropy']

    data_channels_first = np.array([[[[8., 7.1, 0.], [4.5, 2.6, 0.55],
                                      [0.9, 4.2, 11.2]]]], dtype=np.float32)
    # Labels for testing 4-class sparse_categorical_crossentropy, 4-class
    # categorical_crossentropy, and 2-class binary_crossentropy:
    labels_channels_first = [np.array([[[[0, 1, 3], [2, 1, 0], [2, 2, 1]]]], dtype=np.float32),  # pylint: disable=line-too-long
                             np.array([[[[0, 1, 0], [0, 1, 0], [0, 0, 0]],
                                        [[1, 0, 0], [0, 0, 1], [0, 1, 0]],
                                        [[0, 0, 0], [1, 0, 0], [0, 0, 1]],
                                        [[0, 0, 1], [0, 0, 0], [1, 0, 0]]]], dtype=np.float32),  # pylint: disable=line-too-long
                             np.array([[[[0, 1, 0], [0, 1, 0], [0, 0, 1]],
                                        [[1, 0, 1], [1, 0, 1], [1, 1, 0]]]], dtype=np.float32)]  # pylint: disable=line-too-long

    def evaluate_all_losses(data_format, input_shape, convert):
      """Evaluates every loss in `losses_to_test` under `data_format`.

      Args:
        data_format: Image data format to set globally before building models.
        input_shape: Per-sample input shape matching `data_format`.
        convert: Callable mapping a channels_first array to `data_format`.

      Returns:
        A list holding one evaluation loss per entry of `losses_to_test`.
      """
      K.set_image_data_format(data_format)
      data = convert(data_channels_first)
      results = []
      for loss_name, labels_first in zip(losses_to_test,
                                         labels_channels_first):
        labels = convert(labels_first)
        inputs = keras.Input(shape=input_shape)
        model = prepare_simple_model(inputs, loss_name, labels)
        results.append(model.evaluate(x=data, y=labels,
                                      batch_size=1, verbose=0))
      return results

    with self.test_session(use_gpu=True):
      old_data_format = K.image_data_format()
      try:
        loss_channels_last = evaluate_all_losses(
            'channels_last', (3, 3, 1), lambda a: np.moveaxis(a, 1, -1))
        loss_channels_first = evaluate_all_losses(
            'channels_first', (1, 3, 3), lambda a: a)
      finally:
        # image_data_format is process-global state: always restore it so a
        # failure above cannot leak the modified format into other tests.
        K.set_image_data_format(old_data_format)

      np.testing.assert_allclose(
          loss_channels_first,
          loss_channels_last,
          err_msg='Computed different losses for channels_first and '
          'channels_last')


if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test runner
  # (`tensorflow.python.platform.test`).
  test.main()