# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for Substr op from string_ops.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from absl.testing import parameterized import numpy as np from tensorflow.python.framework import errors_impl from tensorflow.python.ops import string_ops from tensorflow.python.platform import test class SubstrOpTest(test.TestCase, parameterized.TestCase): @parameterized.parameters( (np.int32, 1, "BYTE"), (np.int64, 1, "BYTE"), (np.int32, -4, "BYTE"), (np.int64, -4, "BYTE"), (np.int32, 1, "UTF8_CHAR"), (np.int64, 1, "UTF8_CHAR"), (np.int32, -4, "UTF8_CHAR"), (np.int64, -4, "UTF8_CHAR"), ) def testScalarString(self, dtype, pos, unit): test_string = { "BYTE": b"Hello", "UTF8_CHAR": u"He\xc3\xc3\U0001f604".encode("utf-8"), }[unit] expected_value = { "BYTE": b"ell", "UTF8_CHAR": u"e\xc3\xc3".encode("utf-8"), }[unit] position = np.array(pos, dtype) length = np.array(3, dtype) substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) @parameterized.parameters( (np.int32, "BYTE"), (np.int64, "BYTE"), (np.int32, "UTF8_CHAR"), (np.int64, "UTF8_CHAR"), ) def testScalarString_EdgeCases(self, dtype, unit): # Empty string test_string = { "BYTE": b"", "UTF8_CHAR": u"".encode("utf-8"), }[unit] expected_value = b"" position = np.array(0, dtype) length = np.array(3, dtype) substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) # Full string test_string = { "BYTE": b"Hello", "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"), }[unit] position = np.array(0, dtype) length = np.array(5, dtype) substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, test_string) # Full string (Negative) test_string = { "BYTE": b"Hello", "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"), }[unit] position = np.array(-5, dtype) length = np.array(5, dtype) substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, test_string) # Length is larger in magnitude than a negative position test_string = { "BYTE": b"Hello", "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"), }[unit] expected_string = { "BYTE": b"ello", "UTF8_CHAR": u"\xc3ll\U0001f604".encode("utf-8"), }[unit] position = np.array(-4, dtype) length = np.array(5, dtype) substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_string) @parameterized.parameters( (np.int32, 1, "BYTE"), (np.int64, 1, "BYTE"), (np.int32, -4, "BYTE"), (np.int64, -4, "BYTE"), (np.int32, 1, "UTF8_CHAR"), (np.int64, 1, "UTF8_CHAR"), (np.int32, -4, "UTF8_CHAR"), (np.int64, -4, "UTF8_CHAR"), ) def testVectorStrings(self, dtype, pos, unit): test_string = { "BYTE": [b"Hello", b"World"], "UTF8_CHAR": [x.encode("utf-8") for x in [u"H\xc3llo", u"W\U0001f604rld"]], }[unit] expected_value = { "BYTE": [b"ell", b"orl"], "UTF8_CHAR": [x.encode("utf-8") for x in [u"\xc3ll", u"\U0001f604rl"]], }[unit] position = np.array(pos, dtype) length = np.array(3, dtype) substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) @parameterized.parameters( (np.int32, "BYTE"), (np.int64, "BYTE"), (np.int32, "UTF8_CHAR"), (np.int64, "UTF8_CHAR"), ) def testMatrixStrings(self, dtype, unit): test_string = { "BYTE": [[b"ten", b"eleven", b"twelve"], [b"thirteen", b"fourteen", b"fifteen"], [b"sixteen", b"seventeen", b"eighteen"]], "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n", u"\xc6\u053c\u025bv\u025bn", u"tw\u0c1dlv\u025b"]], [x.encode("utf-8") for x in [u"He\xc3\xc3o", u"W\U0001f604rld", u"d\xfcd\xea"]]], }[unit] position = np.array(1, dtype) length = np.array(4, dtype) expected_value = { "BYTE": [[b"en", b"leve", b"welv"], [b"hirt", b"ourt", b"ifte"], [b"ixte", b"even", b"ight"]], "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n", u"\u053c\u025bv\u025b", u"w\u0c1dlv"]], [x.encode("utf-8") for x in [u"e\xc3\xc3o", u"\U0001f604rld", u"\xfcd\xea"]]], }[unit] substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) position = np.array(-3, dtype) length = np.array(2, dtype) expected_value = { "BYTE": [[b"te", b"ve", b"lv"], [b"ee", b"ee", b"ee"], [b"ee", b"ee", b"ee"]], "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227", u"v\u025b", u"lv"]], [x.encode("utf-8") for x in [u"\xc3\xc3", u"rl", u"\xfcd"]]], }[unit] substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) @parameterized.parameters( (np.int32, "BYTE"), (np.int64, "BYTE"), (np.int32, "UTF8_CHAR"), (np.int64, "UTF8_CHAR"), ) def testElementWisePosLen(self, dtype, unit): test_string = { "BYTE": [[b"ten", b"eleven", b"twelve"], [b"thirteen", b"fourteen", b"fifteen"], [b"sixteen", b"seventeen", b"eighteen"]], "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n", u"\xc6\u053c\u025bv\u025bn", u"tw\u0c1dlv\u025b"]], [x.encode("utf-8") for x in [u"He\xc3\xc3o", u"W\U0001f604rld", u"d\xfcd\xea"]], [x.encode("utf-8") for x in [u"sixt\xea\xean", u"se\U00010299enteen", u"ei\U0001e920h\x86een"]]], }[unit] position = np.array([[1, -4, 3], [1, 2, -4], [-5, 2, 3]], dtype) length = np.array([[2, 2, 4], [4, 3, 2], [5, 5, 5]], dtype) expected_value = { "BYTE": [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"], [b"xteen", b"vente", b"hteen"]], "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n", u"\u025bv", u"lv\u025b"]], [x.encode("utf-8") for x in [u"e\xc3\xc3o", u"rld", u"d\xfc"]], [x.encode("utf-8") for x in [u"xt\xea\xean", u"\U00010299ente", u"h\x86een"]]], }[unit] substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) @parameterized.parameters( (np.int32, "BYTE"), (np.int64, "BYTE"), (np.int32, "UTF8_CHAR"), (np.int64, "UTF8_CHAR"), ) def testBroadcast(self, dtype, unit): # Broadcast pos/len onto input string test_string = { "BYTE": [[b"ten", b"eleven", b"twelve"], [b"thirteen", b"fourteen", b"fifteen"], [b"sixteen", b"seventeen", b"eighteen"], [b"nineteen", b"twenty", b"twentyone"]], "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n", u"\xc6\u053c\u025bv\u025bn", u"tw\u0c1dlv\u025b"]], [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean", u"f\U0001f604urt\xea\xean", u"f\xcd\ua09ctee\ua0e4"]], [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean", u"se\U00010299enteen", u"ei\U0001e920h\x86een"]], [x.encode("utf-8") for x in [u"nineteen", u"twenty", u"twentyone"]]], }[unit] position = np.array([1, -4, 3], dtype) length = np.array([1, 2, 3], dtype) expected_value = { "BYTE": [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"], [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]], "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227", u"\u025bv", u"lv\u025b"]], [x.encode("utf-8") for x in [u"h", u"t\xea", u"tee"]], [x.encode("utf-8") for x in [u"\xcd", u"te", u"h\x86e"]], [x.encode("utf-8") for x in [u"i", u"en", u"nty"]]], }[unit] substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) # Broadcast input string onto pos/len test_string = { "BYTE": [b"thirteen", b"fourteen", b"fifteen"], "UTF8_CHAR": [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean", u"f\U0001f604urt\xea\xean", u"f\xcd\ua09ctee\ua0e4"]], }[unit] position = np.array([[1, -2, 3], [-3, 2, 1], [5, 5, -5]], dtype) length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype) expected_value = { "BYTE": [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"], [b"ee", b"ee", b"ft"]], "UTF8_CHAR": [[x.encode("utf-8") for x in [u"h\xcdr", u"\xean", u"t"]], [x.encode("utf-8") for x in [u"\xea", u"ur", u"\xcd\ua09ct"]], [x.encode("utf-8") for x in [u"\xea\xea", u"\xea\xea", u"\ua09ct"]]], }[unit] substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) # Test 1D broadcast test_string = { "BYTE": b"thirteen", "UTF8_CHAR": u"th\xcdrt\xea\xean".encode("utf-8"), }[unit] position = np.array([1, -4, 7], dtype) length = np.array([3, 2, 1], dtype) expected_value = { "BYTE": [b"hir", b"te", b"n"], "UTF8_CHAR": [x.encode("utf-8") for x in [u"h\xcdr", u"t\xea", u"n"]], }[unit] substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) @parameterized.parameters( (np.int32, "BYTE"), (np.int64, "BYTE"), (np.int32, "UTF8_CHAR"), (np.int64, "UTF8_CHAR"), ) def testBadBroadcast(self, dtype, unit): test_string = [[b"ten", b"eleven", b"twelve"], [b"thirteen", b"fourteen", b"fifteen"], [b"sixteen", b"seventeen", b"eighteen"]] position = np.array([1, 2, -3, 4], dtype) length = np.array([1, 2, 3, 4], dtype) with self.assertRaises(ValueError): string_ops.substr(test_string, position, length, unit=unit) @parameterized.parameters( (np.int32, 6, "BYTE"), (np.int64, 6, "BYTE"), (np.int32, -6, "BYTE"), (np.int64, -6, "BYTE"), (np.int32, 6, "UTF8_CHAR"), (np.int64, 6, "UTF8_CHAR"), (np.int32, -6, "UTF8_CHAR"), (np.int64, -6, "UTF8_CHAR"), ) def testOutOfRangeError_Scalar(self, dtype, pos, unit): # Scalar/Scalar test_string = { "BYTE": b"Hello", "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"), }[unit] position = np.array(pos, dtype) length = np.array(3, dtype) substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): substr_op.eval() @parameterized.parameters( (np.int32, 4, "BYTE"), (np.int64, 4, "BYTE"), (np.int32, -4, "BYTE"), (np.int64, -4, "BYTE"), (np.int32, 4, "UTF8_CHAR"), (np.int64, 4, "UTF8_CHAR"), (np.int32, -4, "UTF8_CHAR"), (np.int64, -4, "UTF8_CHAR"), ) def testOutOfRangeError_VectorScalar(self, dtype, pos, unit): # Vector/Scalar test_string = { "BYTE": [b"good", b"good", b"bad", b"good"], "UTF8_CHAR": [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"b\xc3d", u"g\xc3\xc3d"]], }[unit] position = np.array(pos, dtype) length = np.array(1, dtype) substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): substr_op.eval() @parameterized.parameters( (np.int32, "BYTE"), (np.int64, "BYTE"), (np.int32, "UTF8_CHAR"), (np.int64, "UTF8_CHAR"), ) def testOutOfRangeError_MatrixMatrix(self, dtype, unit): # Matrix/Matrix test_string = { "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"], [b"good", b"good", b"good"]], "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d", u"g\xc3\xc3d"]], [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d", u"b\xc3d"]], [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d", u"g\xc3\xc3d"]]], }[unit] position = np.array([[1, 2, 3], [1, 2, 4], [1, 2, 3]], dtype) length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype) substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): substr_op.eval() # Matrix/Matrix (with negative) position = np.array([[1, 2, -3], [1, 2, -4], [1, 2, -3]], dtype) length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype) substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): substr_op.eval() @parameterized.parameters( (np.int32, "BYTE"), (np.int64, "BYTE"), (np.int32, "UTF8_CHAR"), (np.int64, "UTF8_CHAR"), ) def testOutOfRangeError_Broadcast(self, dtype, unit): # Broadcast test_string = { "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]], "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d", u"g\xc3\xc3d"]], [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d", u"b\xc3d"]]], }[unit] position = np.array([1, 2, 4], dtype) length = np.array([1, 2, 3], dtype) substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): substr_op.eval() # Broadcast (with negative) position = np.array([-1, -2, -4], dtype) length = np.array([1, 2, 3], dtype) substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): substr_op.eval() @parameterized.parameters( (np.int32, "BYTE"), (np.int64, "BYTE"), (np.int32, "UTF8_CHAR"), (np.int64, "UTF8_CHAR"), ) def testMismatchPosLenShapes(self, dtype, unit): test_string = { "BYTE": [[b"ten", b"eleven", b"twelve"], [b"thirteen", b"fourteen", b"fifteen"], [b"sixteen", b"seventeen", b"eighteen"]], "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n", u"\xc6\u053c\u025bv\u025bn", u"tw\u0c1dlv\u025b"]], [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean", u"f\U0001f604urt\xea\xean", u"f\xcd\ua09ctee\ua0e4"]], [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean", u"se\U00010299enteen", u"ei\U0001e920h\x86een"]]], }[unit] position = np.array([[1, 2, 3]], dtype) length = np.array([2, 3, 4], dtype) # Should fail: position/length have different rank with self.assertRaises(ValueError): string_ops.substr(test_string, position, length) position = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype) length = np.array([[2, 3, 4]], dtype) # Should fail: position/length have different dimensionality with self.assertRaises(ValueError): string_ops.substr(test_string, position, length) def testWrongDtype(self): with self.cached_session(): with self.assertRaises(TypeError): string_ops.substr(b"test", 3.0, 1) with self.assertRaises(TypeError): string_ops.substr(b"test", 3, 1.0) def testInvalidUnit(self): with self.cached_session(): with self.assertRaises(ValueError): string_ops.substr(b"test", 3, 1, unit="UTF8") if __name__ == "__main__": test.main()