aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/string_to_number_op_test.py
blob: 39505e18bac05fbc05fee6923f6660578cc71ceb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Tests for StringToNumber op from parsing_ops."""

import tensorflow.python.platform

import tensorflow as tf


# Expected prefix of the op error raised when tf.string_to_number cannot parse
# an input string; the tests append the offending input to this prefix and
# match the result with assertRaisesOpError.
_ERROR_MESSAGE = "StringToNumberOp could not correctly convert string: "


class StringToNumberOpTest(tf.test.TestCase):
  """Tests for `tf.string_to_number` with float32 and int32 outputs."""

  def _test(self, tf_type, good_pairs, bad_pairs, exact=False):
    """Runs `tf.string_to_number` and checks its results and errors.

    All successful conversions are fed together in one evaluation. Each
    failing input is then fed on its own, so the raised error can be matched
    against that specific input.

    Args:
      tf_type: The `out_type` passed to `tf.string_to_number`.
      good_pairs: List of `(input_string, expected_number)` pairs that must
        convert successfully.
      bad_pairs: List of `(input_string, expected_error)` pairs; feeding
        `input_string` alone must raise an op error matching `expected_error`.
      exact: If True, compare results with `assertAllEqual` (for integer
        outputs); otherwise use `assertAllClose` (for floating-point outputs).
    """
    with self.test_session():
      input_string = tf.placeholder(tf.string)
      output = tf.string_to_number(input_string, out_type=tf_type)

      # Keeping inputs and expectations paired avoids the two parallel lists
      # drifting out of alignment when cases are added or removed.
      inputs = [instr for instr, _ in good_pairs]
      expected = [outnum for _, outnum in good_pairs]
      result = output.eval(feed_dict={input_string: inputs})
      if exact:
        self.assertAllEqual(expected, result)
      else:
        self.assertAllClose(expected, result)

      for instr, error in bad_pairs:
        with self.assertRaisesOpError(error):
          output.eval(feed_dict={input_string: [instr]})

  def testToFloat(self):
    self._test(tf.float32,
               [("0", 0),
                ("3", 3),
                ("-1", -1),
                ("1.12", 1.12),
                ("0xF", 0xF),
                ("   -10.5", -10.5),
                ("3.40282e+38", 3.40282e+38),
                # The next two lie outside the finite range of float32
                # (magnitude above its maximum), so +/-INF is expected.
                ("3.40283e+38", float("INF")),
                ("-3.40283e+38", float("-INF")),
                ("NAN", float("NAN")),
                ("INF", float("INF"))],
               [("10foobar", _ERROR_MESSAGE + "10foobar")])

  def testToInt32(self):
    self._test(tf.int32,
               [("0", 0),
                ("3", 3),
                ("-1", -1),
                ("    -10", -10),
                ("-2147483648", -2147483648),
                ("2147483647", 2147483647)],
               [("2.9", _ERROR_MESSAGE + "2.9"),
                # The next two are just outside the int32 range
                # [-2147483648, 2147483647]: one below the minimum and one
                # above the maximum, so conversion must fail.
                ("-2147483649", _ERROR_MESSAGE + "-2147483649"),
                ("2147483648", _ERROR_MESSAGE + "2147483648")],
               exact=True)


if __name__ == "__main__":
  # Run this module's test cases through tf.test.main when executed directly.
  tf.test.main()