aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/py_func_test.py
blob: f314712d7cbd628731ac28eac76ec3e12ec49f46 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for py_func op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from six.moves import queue
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import script_ops


class PyOpTest(tf.test.TestCase):
  """Tests for tf.py_func: dtypes, return kinds, statefulness and errors."""

  def testBasic(self):
    """Checks py_func across dtypes, Tout specs and python return kinds."""

    def my_func(x, y):
      return np.sinh(x) + np.cosh(y)

    # single type
    with self.test_session():
      x = tf.constant(1.0, tf.float32)
      y = tf.constant(2.0, tf.float32)
      z = tf.py_func(my_func, [x, y], tf.float32)
      self.assertEqual(z.eval(), my_func(1.0, 2.0).astype(np.float32))

    # scalar inputs; Tout given as a one-element list
    with self.test_session():
      x = tf.constant(1.0, tf.float32)
      y = tf.constant(2.0, tf.float32)
      z = tf.py_func(my_func, [x, y], [tf.float32])
      self.assertEqual(z[0].eval(), my_func(1.0, 2.0).astype(np.float32))

    # array
    with self.test_session():
      x = tf.constant([1.0, 2.0], tf.float64)
      y = tf.constant([2.0, 3.0], tf.float64)
      z = tf.py_func(my_func, [x, y], [tf.float64])
      self.assertAllEqual(
          z[0].eval(),
          my_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))

    # a bit exotic type (complex64)
    with self.test_session():
      x = tf.constant(1+2j, tf.complex64)
      y = tf.constant(3+4j, tf.complex64)
      z, = tf.py_func(my_func, [x, y], [tf.complex64])
      self.assertAllClose(z.eval(), my_func(1+2j, 3+4j))

    # a bit exotic function (rfft)
    with self.test_session():
      x = tf.constant([1., 2., 3., 4.], tf.float32)
      def rfft(x):
        return np.fft.rfft(x).astype(np.complex64)
      y, = tf.py_func(rfft, [x], [tf.complex64])
      self.assertAllClose(y.eval(), np.fft.rfft([1., 2., 3., 4.]))

    # returns a python literal.
    with self.test_session():
      def literal(x):
        return 1.0 if x == 0.0 else 0.0
      x = tf.constant(0.0, tf.float64)
      y, = tf.py_func(literal, [x], [tf.float64])
      self.assertAllClose(y.eval(), 1.0)

    # returns a list
    with self.test_session():
      def list_func(x):
        return [x, x + 1]
      x = tf.constant(0.0, tf.float64)
      y, z = tf.py_func(list_func, [x], [tf.float64] * 2)
      self.assertAllClose(y.eval(), 0.0)
      self.assertAllClose(z.eval(), 1.0)

    # returns a tuple
    with self.test_session():
      def tuple_func(x):
        return x, x + 1
      x = tf.constant(0.0, tf.float64)
      y, z = tf.py_func(tuple_func, [x], [tf.float64] * 2)
      self.assertAllClose(y.eval(), 0.0)
      self.assertAllClose(z.eval(), 1.0)

  def testStrings(self):
    """Checks tf.string inputs and outputs, including a numpy bytes array."""

    def read_fixed_length_numpy_strings():
      return np.array([b" there"])

    def read_and_return_strings(x, y):
      return x + y

    with self.test_session():
      x = tf.constant([b"hello", b"hi"], tf.string)
      y, = tf.py_func(read_fixed_length_numpy_strings, [], [tf.string])
      z, = tf.py_func(read_and_return_strings, [x, y], [tf.string])
      self.assertListEqual(list(z.eval()), [b"hello there", b"hi there"])

  def testStringPadding(self):
    """Strings of differing lengths must come back exactly as returned."""
    correct = [b"this", b"is", b"a", b"test"]
    with self.test_session():
      s, = tf.py_func(lambda: [correct], [], [tf.string])
      self.assertAllEqual(s.eval(), correct)

  def testLarge(self):
    """Repeatedly runs two py_funcs over a 1M-element float32 tensor."""
    with self.test_session() as sess:
      x = tf.zeros([1000000], dtype=np.float32)
      y = tf.py_func(lambda x: x + 1, [x], [tf.float32])
      z = tf.py_func(lambda x: x * 2, [x], [tf.float32])
      for _ in xrange(100):
        sess.run([y[0].op, z[0].op])

  def testNoInput(self):
    """A py_func with no inputs can return a plain python float."""
    with self.test_session():
      x, = tf.py_func(lambda: 42.0, [], [tf.float64])
      self.assertAllClose(x.eval(), 42.0)

  def testCleanup(self):
    """Discarded graphs must not keep all their py_func registrations alive."""
    for _ in xrange(1000):
      g = tf.Graph()
      with g.as_default():
        c = tf.constant([1.], tf.float32)
        _ = tf.py_func(lambda x: x + 1, [c], [tf.float32])
    # assertLess reports the actual registry size on failure.
    self.assertLess(script_ops._py_funcs.size(), 100)

  def testBadNumpyReturnType(self):
    """Returning a structured numpy array raises UnimplementedError."""
    with self.test_session():

      def bad():
        # Structured numpy arrays aren't supported.
        return np.array([], dtype=[("foo", np.float32)])

      y, = tf.py_func(bad, [], [tf.float32])

      with self.assertRaisesRegexp(errors.UnimplementedError,
                                   "Unsupported numpy type"):
        y.eval()

  def testBadReturnType(self):
    """Returning a non-string python object raises UnimplementedError."""
    with self.test_session():

      def bad():
        # Non-string python objects aren't supported.
        return tf.float32

      z, = tf.py_func(bad, [], [tf.float64])

      with self.assertRaisesRegexp(errors.UnimplementedError,
                                   "Unsupported object type"):
        z.eval()

  def testStateful(self):
    """A stateful py_func (the default) is re-executed on every run."""
    # Not using self.test_session(), which disables optimization.
    with tf.Session() as sess:
      producer = iter(range(3))
      x, = tf.py_func(lambda: next(producer), [], [tf.int64])
      self.assertEqual(sess.run(x), 0)
      self.assertEqual(sess.run(x), 1)
      self.assertEqual(sess.run(x), 2)

  def testStateless(self):
    """With stateful=False every run is expected to yield the same value."""
    # Not using self.test_session(), which disables optimization.
    with tf.Session() as sess:
      producer = iter(range(3))
      x, = tf.py_func(lambda: next(producer), [], [tf.int64], stateful=False)
      self.assertEqual(sess.run(x), 0)
      self.assertEqual(sess.run(x), 0)
      self.assertEqual(sess.run(x), 0)

  def testGradientFunction(self):
    """py_func ops, stateful or not, have no gradient function."""
    # Input to tf.py_func is necessary, otherwise get_gradient_function()
    # returns None per default.
    a = tf.constant(0)
    x, = tf.py_func(lambda a: 0, [a], [tf.int64])
    y, = tf.py_func(lambda a: 0, [a], [tf.int64], stateful=False)
    self.assertIsNone(ops.get_gradient_function(x.op))
    self.assertIsNone(ops.get_gradient_function(y.op))

  def testCOrder(self):
    """A Fortran-ordered numpy result is returned with the correct values."""
    with self.test_session():
      val = [[1, 2], [3, 4]]
      x, = tf.py_func(lambda: np.array(val, order="F"), [], [tf.int64])
      self.assertAllEqual(val, x.eval())

  def testParallel(self):
    """Two py_funcs that block on each other must be able to run concurrently."""
    # Tests that tf.py_func's can run in parallel if they release the GIL.
    with self.test_session() as session:
      q = queue.Queue(1)

      def blocking_put():
        q.put(42)
        q.join()  # Wait for task_done().
        return 42

      def blocking_get():
        v = q.get(block=True)  # Wait for put().
        q.task_done()
        return v

      x, = tf.py_func(blocking_put, [], [tf.int64])
      y, = tf.py_func(blocking_get, [], [tf.int64])

      # This will result in a deadlock if the py_func's don't run in parallel.
      session.run([x, y])


if __name__ == "__main__":
  # Run this module's test cases through TensorFlow's test entry point.
  tf.test.main()