aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/aggregate_ops_test.py
blob: 0a08c01dad38f9b31b775b25f01db2e4361df552 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for aggregate_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.core.framework import tensor_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class AddNTest(test.TestCase):
  """Tests for `math_ops.add_n` across dtypes, input counts and shapes."""

  # AddN special-cases adding the first M inputs to make (N - M) divisible by 8,
  # after which it adds the remaining (N - M) tensors 8 at a time in a loop.
  # Test N in [1, 10] so we check each special-case from 1 to 9 and one
  # iteration of the loop.
  _MAX_N = 10

  def _supported_types(self):
    """Returns the dtypes to exercise on the current device.

    When a GPU is available only floating-point and complex dtypes are
    tested; the integer dtypes are added only on CPU-only runs.
    NOTE(review): presumably the GPU AddN kernel is not registered for the
    integer types — confirm before extending this list.
    """
    if test.is_gpu_available():
      return [dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64,
              dtypes.complex128]
    return [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
            dtypes.float16, dtypes.float32, dtypes.float64, dtypes.complex64,
            dtypes.complex128]

  def _buildData(self, shape, dtype):
    """Returns standard-normal random data of `shape` cast to `dtype`.

    Args:
      shape: Sequence of ints giving the array shape.
      dtype: A `dtypes.DType`; its NumPy equivalent is used for the cast, so
        integer dtypes receive the truncated normal samples.

    Returns:
      A NumPy array of the requested shape and dtype.
    """
    data = np.random.randn(*shape).astype(dtype.as_numpy_dtype)
    # For complex types, make the imaginary part a distinct multiple (10x) of
    # the real part so a swapped or dropped component changes the result.
    if dtype.is_complex:
      return data + 10j * data
    return data

  def testAddN(self):
    """Compares add_n with a NumPy sum for 1.._MAX_N inputs of each dtype."""
    np.random.seed(12345)
    with self.test_session(use_gpu=True) as sess:
      for dtype in self._supported_types():
        # float16 accumulates far more rounding error than the wider types.
        tol = 5e-3 if dtype == dtypes.float16 else 5e-7
        for count in range(1, self._MAX_N + 1):
          data = [self._buildData((2, 2), dtype) for _ in range(count)]
          actual = sess.run(math_ops.add_n(data))
          # Stack to shape (count, 2, 2) and reduce over the new leading axis.
          expected = np.sum(np.stack(data), axis=0)
          self.assertAllClose(expected, actual, rtol=tol, atol=tol)

  def testUnknownShapes(self):
    """Checks add_n when its inputs have no static shape at graph build time."""
    np.random.seed(12345)
    with self.test_session(use_gpu=True) as sess:
      for dtype in self._supported_types():
        tol = 5e-3 if dtype == dtypes.float16 else 5e-7
        data = self._buildData((2, 2), dtype)
        # A placeholder created without a shape hides the static shape from
        # add_n. It depends only on dtype, so build it once per dtype rather
        # than adding a fresh graph node for every `count`.
        data_ph = array_ops.placeholder(dtype=dtype)
        for count in range(1, self._MAX_N + 1):
          actual = sess.run(math_ops.add_n([data_ph] * count), {data_ph: data})
          expected = np.sum(np.stack([data] * count), axis=0)
          self.assertAllClose(expected, actual, rtol=tol, atol=tol)

  def testVariant(self):
    """Smoke-tests add_n on scalar DT_VARIANT tensors that wrap int32 values.

    The summed variant cannot be fetched through session.run, so the result
    is only printed for manual inspection rather than asserted.
    """

    def create_constant_variant(value):
      """Returns a scalar variant constant holding `value` as int32 bytes."""
      return constant_op.constant(
          tensor_pb2.TensorProto(
              dtype=dtypes.variant.as_datatype_enum,
              tensor_shape=tensor_shape.TensorShape([]).as_proto(),
              variant_val=[
                  tensor_pb2.VariantTensorDataProto(
                      # Match registration in variant_op_registry.cc
                      type_name=b"int",
                      metadata=np.array(value, dtype=np.int32).tobytes())
              ]))

    # TODO(ebrevdo): Re-enable use_gpu=True once non-DMA Variant
    # copying between CPU and GPU is supported.
    with self.test_session(use_gpu=False):
      variant_const_3 = create_constant_variant(3)
      variant_const_4 = create_constant_variant(4)
      variant_const_5 = create_constant_variant(5)
      # 3 + 3 + 5 + 4 = 15.
      result = math_ops.add_n((variant_const_3, variant_const_3,
                               variant_const_5, variant_const_4))

      # Smoke test -- ensure this executes without trouble.
      # Right now, non-numpy-compatible objects cannot be returned from a
      # session.run call; similarly, objects that can't be converted to
      # native numpy types cannot be passed to ops.convert_to_tensor.
      # For now, run the test and examine the output to see that the result is
      # equal to 15.
      result_op = logging_ops.Print(
          result, [variant_const_3, variant_const_4, variant_const_5, result],
          message=("Variants stored an int: c(3), c(4), c(5), "
                   "add_n(c(3), c(3), c(5), c(4)): ")).op
      result_op.run()


if __name__ == "__main__":
  # When executed as a script, hand off to TensorFlow's test entry point
  # (tensorflow.python.platform.test) to run this module's test cases.
  test.main()