1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for StringToNumber op from parsing_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
_ERROR_MESSAGE = "StringToNumberOp could not correctly convert string: "
class StringToNumberOpTest(test.TestCase):
  """Tests `parsing_ops.string_to_number` for each supported `out_type`."""

  def _test(self, tf_type, good_pairs, bad_pairs):
    """Feeds each input string through `string_to_number` and checks it.

    Args:
      tf_type: The `out_type` dtype passed to `string_to_number`.
      good_pairs: Iterable of `(input_string, expected_number)` tuples; the
        converted value must be close to `expected_number`.
      bad_pairs: Iterable of `(input_string, expected_error)` tuples; the op
        must raise an error whose message matches `expected_error` taken as
        a literal string.
    """
    with self.cached_session():
      # Build a small testing graph; one placeholder is reused for every
      # input string.
      input_string = array_ops.placeholder(dtypes.string)
      output = parsing_ops.string_to_number(input_string, out_type=tf_type)

      # Check all the good input/output pairs.
      for instr, outnum in good_pairs:
        result, = output.eval(feed_dict={input_string: [instr]})
        self.assertAllClose([outnum], [result])

      # Check that the bad inputs produce the right errors.
      for instr, outstr in bad_pairs:
        # assertRaisesOpError interprets its argument as a regular
        # expression; escape it so characters such as "." or "+" coming
        # from the input string are matched literally.
        with self.assertRaisesOpError(re.escape(outstr)):
          output.eval(feed_dict={input_string: [instr]})

  def testToFloat(self):
    """float32 conversion, including overflow to +/-inf, NaN and hex input."""
    self._test(dtypes.float32,
               [("0", 0), ("3", 3), ("-1", -1),
                ("1.12", 1.12), ("0xF", 15), (" -10.5", -10.5),
                ("3.40282e+38", 3.40282e+38),
                # Greater than the largest finite float32; overflows to +inf.
                ("3.40283e+38", float("INF")),
                # Less than the lowest finite float32; overflows to -inf.
                ("-3.40283e+38", float("-INF")),
                ("NAN", float("NAN")),
                ("INF", float("INF"))],
               [("10foobar", _ERROR_MESSAGE + "10foobar")])

  def testToDouble(self):
    """float64 conversion; values outside float32 range stay finite."""
    self._test(dtypes.float64,
               [("0", 0), ("3", 3), ("-1", -1),
                ("1.12", 1.12), ("0xF", 15), (" -10.5", -10.5),
                ("3.40282e+38", 3.40282e+38),
                # Greater than the largest finite float32, but representable
                # as a double.
                ("3.40283e+38", 3.40283e+38),
                # Less than the lowest finite float32, but representable as
                # a double.
                ("-3.40283e+38", -3.40283e+38),
                ("NAN", float("NAN")),
                ("INF", float("INF"))],
               [("10foobar", _ERROR_MESSAGE + "10foobar")])

  def testToInt32(self):
    """int32 conversion; out-of-range and non-integer input is rejected."""
    self._test(dtypes.int32,
               [("0", 0), ("3", 3), ("-1", -1),
                (" -10", -10),
                ("-2147483648", -2147483648),
                ("2147483647", 2147483647)],
               [
                   # Less than min value of int32.
                   ("-2147483649", _ERROR_MESSAGE + "-2147483649"),
                   # Greater than max value of int32.
                   ("2147483648", _ERROR_MESSAGE + "2147483648"),
                   ("2.9", _ERROR_MESSAGE + "2.9"),
                   ("10foobar", _ERROR_MESSAGE + "10foobar"),
               ])

  def testToInt64(self):
    """int64 conversion; values outside int32 range are accepted."""
    self._test(dtypes.int64,
               [("0", 0), ("3", 3), ("-1", -1),
                (" -10", -10),
                ("-2147483648", -2147483648),
                ("2147483647", 2147483647),
                ("-2147483649", -2147483649),  # Less than min value of int32.
                ("2147483648", 2147483648)],  # Greater than max value of int32.
               [("2.9", _ERROR_MESSAGE + "2.9"),
                ("10foobar", _ERROR_MESSAGE + "10foobar")])
# Run the test cases when this module is executed directly.
if __name__ == "__main__":
  test.main()
|