diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-03-23 08:47:24 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-23 10:13:43 -0700 |
commit | fff77b855e6fc189dc342d425ac2a5b21e7c0e53 (patch) | |
tree | 914581a81110efb7c665abad906bf1b4ac7c857c | |
parent | 2226b6f600a1c9beb3c21dee8819551e8b4a0a05 (diff) |
tf.string_to_number: Add support for int64 and float64.
Change: 151015024
-rw-r--r-- | tensorflow/core/kernels/string_to_number_op.cc | 16 | ||||
-rw-r--r-- | tensorflow/core/ops/parsing_ops.cc | 2 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/string_to_number_op_test.py | 101 |
3 files changed, 75 insertions, 44 deletions
diff --git a/tensorflow/core/kernels/string_to_number_op.cc b/tensorflow/core/kernels/string_to_number_op.cc index 6812671091..d583e4e6bb 100644 --- a/tensorflow/core/kernels/string_to_number_op.cc +++ b/tensorflow/core/kernels/string_to_number_op.cc @@ -66,12 +66,26 @@ void StringToNumberOp<float>::Convert(const string& s, float* output_data, } template <> +void StringToNumberOp<double>::Convert(const string& s, double* output_data, + OpKernelContext* context) { + OP_REQUIRES(context, strings::safe_strtod(s.c_str(), output_data), + errors::InvalidArgument(kErrorMessage, s)); +} + +template <> void StringToNumberOp<int32>::Convert(const string& s, int32* output_data, OpKernelContext* context) { OP_REQUIRES(context, strings::safe_strto32(s, output_data), errors::InvalidArgument(kErrorMessage, s)); } +template <> +void StringToNumberOp<int64>::Convert(const string& s, int64* output_data, + OpKernelContext* context) { + OP_REQUIRES(context, strings::safe_strto64(s, output_data), + errors::InvalidArgument(kErrorMessage, s)); +} + // Registers the currently supported output types. #define REGISTER(type) \ REGISTER_KERNEL_BUILDER(Name("StringToNumber") \ @@ -79,7 +93,9 @@ void StringToNumberOp<int32>::Convert(const string& s, int32* output_data, .TypeConstraint<type>("out_type"), \ StringToNumberOp<type>) REGISTER(float); +REGISTER(double); REGISTER(int32); +REGISTER(int64); #undef REGISTER } // namespace tensorflow diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc index b563656f39..2af2955c19 100644 --- a/tensorflow/core/ops/parsing_ops.cc +++ b/tensorflow/core/ops/parsing_ops.cc @@ -354,7 +354,7 @@ output: Each tensor will have the same shape as records. REGISTER_OP("StringToNumber") .Input("string_tensor: string") .Output("output: out_type") - .Attr("out_type: {float, int32} = DT_FLOAT") + .Attr("out_type: {float, double, int32, int64} = DT_FLOAT") .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Converts each string in the input Tensor to the specified numeric type. diff --git a/tensorflow/python/kernel_tests/string_to_number_op_test.py b/tensorflow/python/kernel_tests/string_to_number_op_test.py index 8a7a7285a6..cc4c21b66c 100644 --- a/tensorflow/python/kernel_tests/string_to_number_op_test.py +++ b/tensorflow/python/kernel_tests/string_to_number_op_test.py @@ -28,57 +28,72 @@ _ERROR_MESSAGE = "StringToNumberOp could not correctly convert string: " class StringToNumberOpTest(test.TestCase): - def testToFloat(self): + def _test(self, tf_type, good_pairs, bad_pairs): with self.test_session(): + # Build a small testing graph. input_string = array_ops.placeholder(dtypes.string) output = parsing_ops.string_to_number( - input_string, out_type=dtypes.float32) + input_string, out_type=tf_type) - result = output.eval(feed_dict={ - input_string: [ - "0", - "3", - "-1", - "1.12", - "0xF", - " -10.5", - "3.40282e+38", - # The next two exceed maximum value for float, so we - # expect +/-INF to be returned instead. - "3.40283e+38", - "-3.40283e+38", - "NAN", - "INF" - ] - }) + # Check all the good input/output pairs. + for instr, outnum in good_pairs: + result, = output.eval(feed_dict={input_string: [instr]}) + self.assertAllClose([outnum], [result]) - self.assertAllClose([ - 0, 3, -1, 1.12, 0xF, -10.5, 3.40282e+38, float("INF"), float("-INF"), - float("NAN"), float("INF") - ], result) + # Check that the bad inputs produce the right errors. + for instr, outstr in bad_pairs: + with self.assertRaisesOpError(outstr): + output.eval(feed_dict={input_string: [instr]}) - with self.assertRaisesOpError(_ERROR_MESSAGE + "10foobar"): - output.eval(feed_dict={input_string: ["10foobar"]}) + def testToFloat(self): + self._test(dtypes.float32, + [("0", 0), ("3", 3), ("-1", -1), + ("1.12", 1.12), ("0xF", 15), (" -10.5", -10.5), + ("3.40282e+38", 3.40282e+38), + # Greater than max value of float. + ("3.40283e+38", float("INF")), + ("-3.40283e+38", float("-INF")), + # Less than min value of float. + ("NAN", float("NAN")), + ("INF", float("INF"))], + [("10foobar", _ERROR_MESSAGE + "10foobar")]) + + def testToDouble(self): + self._test(dtypes.float64, + [("0", 0), ("3", 3), ("-1", -1), + ("1.12", 1.12), ("0xF", 15), (" -10.5", -10.5), + ("3.40282e+38", 3.40282e+38), + # Greater than max value of float. + ("3.40283e+38", 3.40283e+38), + # Less than min value of float. + ("-3.40283e+38", -3.40283e+38), + ("NAN", float("NAN")), + ("INF", float("INF"))], + [("10foobar", _ERROR_MESSAGE + "10foobar")]) def testToInt32(self): - with self.test_session(): - input_string = array_ops.placeholder(dtypes.string) - output = parsing_ops.string_to_number(input_string, out_type=dtypes.int32) - - result = output.eval(feed_dict={ - input_string: - ["0", "3", "-1", " -10", "-2147483648", "2147483647"] - }) - - self.assertAllEqual([0, 3, -1, -10, -2147483648, 2147483647], result) - - with self.assertRaisesOpError(_ERROR_MESSAGE + "2.9"): - output.eval(feed_dict={input_string: ["2.9"]}) - - # The next two exceed maximum value of int32. - for in_string in ["-2147483649", "2147483648"]: - with self.assertRaisesOpError(_ERROR_MESSAGE + in_string): - output.eval(feed_dict={input_string: [in_string]}) + self._test(dtypes.int32, + [("0", 0), ("3", 3), ("-1", -1), + (" -10", -10), + ("-2147483648", -2147483648), + ("2147483647", 2147483647)], + [ # Less than min value of int32. + ("-2147483649", _ERROR_MESSAGE + "-2147483649"), + # Greater than max value of int32. + ("2147483648", _ERROR_MESSAGE + "2147483648"), + ("2.9", _ERROR_MESSAGE + "2.9"), + ("10foobar", _ERROR_MESSAGE + "10foobar")]) + + def testToInt64(self): + self._test(dtypes.int64, + [("0", 0), ("3", 3), ("-1", -1), + (" -10", -10), + ("-2147483648", -2147483648), + ("2147483647", 2147483647), + ("-2147483649", -2147483649), # Less than min value of int32. + ("2147483648", 2147483648)], # Greater than max value of int32. + [("2.9", _ERROR_MESSAGE + "2.9"), + ("10foobar", _ERROR_MESSAGE + "10foobar")]) if __name__ == "__main__": |