aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-23 08:47:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-23 10:13:43 -0700
commitfff77b855e6fc189dc342d425ac2a5b21e7c0e53 (patch)
tree914581a81110efb7c665abad906bf1b4ac7c857c
parent2226b6f600a1c9beb3c21dee8819551e8b4a0a05 (diff)
tf.string_to_number: Add support for int64 and float64.
Change: 151015024
-rw-r--r--tensorflow/core/kernels/string_to_number_op.cc16
-rw-r--r--tensorflow/core/ops/parsing_ops.cc2
-rw-r--r--tensorflow/python/kernel_tests/string_to_number_op_test.py101
3 files changed, 75 insertions, 44 deletions
diff --git a/tensorflow/core/kernels/string_to_number_op.cc b/tensorflow/core/kernels/string_to_number_op.cc
index 6812671091..d583e4e6bb 100644
--- a/tensorflow/core/kernels/string_to_number_op.cc
+++ b/tensorflow/core/kernels/string_to_number_op.cc
@@ -66,12 +66,26 @@ void StringToNumberOp<float>::Convert(const string& s, float* output_data,
}
template <>
+void StringToNumberOp<double>::Convert(const string& s, double* output_data,
+ OpKernelContext* context) {
+ OP_REQUIRES(context, strings::safe_strtod(s.c_str(), output_data),
+ errors::InvalidArgument(kErrorMessage, s));
+}
+
+template <>
void StringToNumberOp<int32>::Convert(const string& s, int32* output_data,
OpKernelContext* context) {
OP_REQUIRES(context, strings::safe_strto32(s, output_data),
errors::InvalidArgument(kErrorMessage, s));
}
+template <>
+void StringToNumberOp<int64>::Convert(const string& s, int64* output_data,
+ OpKernelContext* context) {
+ OP_REQUIRES(context, strings::safe_strto64(s, output_data),
+ errors::InvalidArgument(kErrorMessage, s));
+}
+
// Registers the currently supported output types.
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER(Name("StringToNumber") \
@@ -79,7 +93,9 @@ void StringToNumberOp<int32>::Convert(const string& s, int32* output_data,
.TypeConstraint<type>("out_type"), \
StringToNumberOp<type>)
REGISTER(float);
+REGISTER(double);
REGISTER(int32);
+REGISTER(int64);
#undef REGISTER
} // namespace tensorflow
diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc
index b563656f39..2af2955c19 100644
--- a/tensorflow/core/ops/parsing_ops.cc
+++ b/tensorflow/core/ops/parsing_ops.cc
@@ -354,7 +354,7 @@ output: Each tensor will have the same shape as records.
REGISTER_OP("StringToNumber")
.Input("string_tensor: string")
.Output("output: out_type")
- .Attr("out_type: {float, int32} = DT_FLOAT")
+ .Attr("out_type: {float, double, int32, int64} = DT_FLOAT")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Converts each string in the input Tensor to the specified numeric type.
diff --git a/tensorflow/python/kernel_tests/string_to_number_op_test.py b/tensorflow/python/kernel_tests/string_to_number_op_test.py
index 8a7a7285a6..cc4c21b66c 100644
--- a/tensorflow/python/kernel_tests/string_to_number_op_test.py
+++ b/tensorflow/python/kernel_tests/string_to_number_op_test.py
@@ -28,57 +28,72 @@ _ERROR_MESSAGE = "StringToNumberOp could not correctly convert string: "
class StringToNumberOpTest(test.TestCase):
- def testToFloat(self):
+ def _test(self, tf_type, good_pairs, bad_pairs):
with self.test_session():
+ # Build a small testing graph.
input_string = array_ops.placeholder(dtypes.string)
output = parsing_ops.string_to_number(
- input_string, out_type=dtypes.float32)
+ input_string, out_type=tf_type)
- result = output.eval(feed_dict={
- input_string: [
- "0",
- "3",
- "-1",
- "1.12",
- "0xF",
- " -10.5",
- "3.40282e+38",
- # The next two exceed maximum value for float, so we
- # expect +/-INF to be returned instead.
- "3.40283e+38",
- "-3.40283e+38",
- "NAN",
- "INF"
- ]
- })
+ # Check all the good input/output pairs.
+ for instr, outnum in good_pairs:
+ result, = output.eval(feed_dict={input_string: [instr]})
+ self.assertAllClose([outnum], [result])
- self.assertAllClose([
- 0, 3, -1, 1.12, 0xF, -10.5, 3.40282e+38, float("INF"), float("-INF"),
- float("NAN"), float("INF")
- ], result)
+ # Check that the bad inputs produce the right errors.
+ for instr, outstr in bad_pairs:
+ with self.assertRaisesOpError(outstr):
+ output.eval(feed_dict={input_string: [instr]})
- with self.assertRaisesOpError(_ERROR_MESSAGE + "10foobar"):
- output.eval(feed_dict={input_string: ["10foobar"]})
+ def testToFloat(self):
+ self._test(dtypes.float32,
+ [("0", 0), ("3", 3), ("-1", -1),
+ ("1.12", 1.12), ("0xF", 15), (" -10.5", -10.5),
+ ("3.40282e+38", 3.40282e+38),
+ # Greater than max value of float.
+ ("3.40283e+38", float("INF")),
+ ("-3.40283e+38", float("-INF")),
+ # Less than min value of float.
+ ("NAN", float("NAN")),
+ ("INF", float("INF"))],
+ [("10foobar", _ERROR_MESSAGE + "10foobar")])
+
+ def testToDouble(self):
+ self._test(dtypes.float64,
+ [("0", 0), ("3", 3), ("-1", -1),
+ ("1.12", 1.12), ("0xF", 15), (" -10.5", -10.5),
+ ("3.40282e+38", 3.40282e+38),
+ # Greater than max value of float.
+ ("3.40283e+38", 3.40283e+38),
+ # Less than min value of float.
+ ("-3.40283e+38", -3.40283e+38),
+ ("NAN", float("NAN")),
+ ("INF", float("INF"))],
+ [("10foobar", _ERROR_MESSAGE + "10foobar")])
def testToInt32(self):
- with self.test_session():
- input_string = array_ops.placeholder(dtypes.string)
- output = parsing_ops.string_to_number(input_string, out_type=dtypes.int32)
-
- result = output.eval(feed_dict={
- input_string:
- ["0", "3", "-1", " -10", "-2147483648", "2147483647"]
- })
-
- self.assertAllEqual([0, 3, -1, -10, -2147483648, 2147483647], result)
-
- with self.assertRaisesOpError(_ERROR_MESSAGE + "2.9"):
- output.eval(feed_dict={input_string: ["2.9"]})
-
- # The next two exceed maximum value of int32.
- for in_string in ["-2147483649", "2147483648"]:
- with self.assertRaisesOpError(_ERROR_MESSAGE + in_string):
- output.eval(feed_dict={input_string: [in_string]})
+ self._test(dtypes.int32,
+ [("0", 0), ("3", 3), ("-1", -1),
+ (" -10", -10),
+ ("-2147483648", -2147483648),
+ ("2147483647", 2147483647)],
+ [ # Less than min value of int32.
+ ("-2147483649", _ERROR_MESSAGE + "-2147483649"),
+ # Greater than max value of int32.
+ ("2147483648", _ERROR_MESSAGE + "2147483648"),
+ ("2.9", _ERROR_MESSAGE + "2.9"),
+ ("10foobar", _ERROR_MESSAGE + "10foobar")])
+
+ def testToInt64(self):
+ self._test(dtypes.int64,
+ [("0", 0), ("3", 3), ("-1", -1),
+ (" -10", -10),
+ ("-2147483648", -2147483648),
+ ("2147483647", 2147483647),
+ ("-2147483649", -2147483649), # Less than min value of int32.
+ ("2147483648", 2147483648)], # Greater than max value of int32.
+ [("2.9", _ERROR_MESSAGE + "2.9"),
+ ("10foobar", _ERROR_MESSAGE + "10foobar")])
if __name__ == "__main__":