diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-04 11:30:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 11:34:46 -0700 |
commit | 700c3325311e16be9bb4856cbf944d1871ff35c1 (patch) | |
tree | 9ae88328889950abaa951a628de7212caec8c026 /tensorflow/python/kernel_tests | |
parent | c8d5054e8c12800f0c3db0e51f3d5902e04eaa37 (diff) |
Add "unit" attribute to string substr op, which controls how each "character" is treated:
* BYTE: Position & length refer to bytes in the string. (Default)
* UTF8_CHAR: The string is interpreted as UTF-8 encoded Unicode code points, and position & length are treated relative to them.
RELNOTES: Add option to get substring using Unicode characters
PiperOrigin-RevId: 215773373
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r-- | tensorflow/python/kernel_tests/substr_op_test.py | 503 |
1 file changed, 343 insertions, 160 deletions
diff --git a/tensorflow/python/kernel_tests/substr_op_test.py b/tensorflow/python/kernel_tests/substr_op_test.py index cd3fe14883..37aa624b07 100644 --- a/tensorflow/python/kernel_tests/substr_op_test.py +++ b/tensorflow/python/kernel_tests/substr_op_test.py @@ -28,270 +28,448 @@ from tensorflow.python.platform import test class SubstrOpTest(test.TestCase, parameterized.TestCase): - def _testScalarString(self, dtype): - test_string = b"Hello" - position = np.array(1, dtype) + @parameterized.parameters( + (np.int32, 1, "BYTE"), + (np.int64, 1, "BYTE"), + (np.int32, -4, "BYTE"), + (np.int64, -4, "BYTE"), + (np.int32, 1, "UTF8_CHAR"), + (np.int64, 1, "UTF8_CHAR"), + (np.int32, -4, "UTF8_CHAR"), + (np.int64, -4, "UTF8_CHAR"), + ) + def testScalarString(self, dtype, pos, unit): + test_string = { + "BYTE": b"Hello", + "UTF8_CHAR": u"He\xc3\xc3\U0001f604".encode("utf-8"), + }[unit] + expected_value = { + "BYTE": b"ell", + "UTF8_CHAR": u"e\xc3\xc3".encode("utf-8"), + }[unit] + position = np.array(pos, dtype) length = np.array(3, dtype) - expected_value = b"ell" - - substr_op = string_ops.substr(test_string, position, length) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) - # Negative position. 
- test_string = b"Hello" - position = np.array(-4, dtype) + @parameterized.parameters( + (np.int32, "BYTE"), + (np.int64, "BYTE"), + (np.int32, "UTF8_CHAR"), + (np.int64, "UTF8_CHAR"), + ) + def testScalarString_EdgeCases(self, dtype, unit): + # Empty string + test_string = { + "BYTE": b"", + "UTF8_CHAR": u"".encode("utf-8"), + }[unit] + expected_value = b"" + position = np.array(0, dtype) length = np.array(3, dtype) - expected_value = b"ell" - - substr_op = string_ops.substr(test_string, position, length) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) - # Position is equal to the length of string. - test_string = b"" + # Full string + test_string = { + "BYTE": b"Hello", + "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"), + }[unit] position = np.array(0, dtype) - length = np.array(2, dtype) - expected_value = b"" - - substr_op = string_ops.substr(test_string, position, length) + length = np.array(5, dtype) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() - self.assertAllEqual(substr, expected_value) - - # Negative position magnitude is equal to the length of string. 
- test_string = b"yo" - position = np.array(-2, dtype) - length = np.array(1, dtype) - expected_value = b"y" - - substr_op = string_ops.substr(test_string, position, length) + self.assertAllEqual(substr, test_string) + + # Full string (Negative) + test_string = { + "BYTE": b"Hello", + "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"), + }[unit] + position = np.array(-5, dtype) + length = np.array(5, dtype) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() - self.assertAllEqual(substr, expected_value) - - def _testVectorStrings(self, dtype): - test_string = [b"Hello", b"World"] - position = np.array(1, dtype) - length = np.array(3, dtype) - expected_value = [b"ell", b"orl"] - - substr_op = string_ops.substr(test_string, position, length) + self.assertAllEqual(substr, test_string) + + # Length is larger in magnitude than a negative position + test_string = { + "BYTE": b"Hello", + "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"), + }[unit] + expected_string = { + "BYTE": b"ello", + "UTF8_CHAR": u"\xc3ll\U0001f604".encode("utf-8"), + }[unit] + position = np.array(-4, dtype) + length = np.array(5, dtype) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() - self.assertAllEqual(substr, expected_value) - - # Negative position. 
- test_string = [b"Hello", b"World"] - position = np.array(-4, dtype) + self.assertAllEqual(substr, expected_string) + + @parameterized.parameters( + (np.int32, 1, "BYTE"), + (np.int64, 1, "BYTE"), + (np.int32, -4, "BYTE"), + (np.int64, -4, "BYTE"), + (np.int32, 1, "UTF8_CHAR"), + (np.int64, 1, "UTF8_CHAR"), + (np.int32, -4, "UTF8_CHAR"), + (np.int64, -4, "UTF8_CHAR"), + ) + def testVectorStrings(self, dtype, pos, unit): + test_string = { + "BYTE": [b"Hello", b"World"], + "UTF8_CHAR": [x.encode("utf-8") for x in [u"H\xc3llo", + u"W\U0001f604rld"]], + }[unit] + expected_value = { + "BYTE": [b"ell", b"orl"], + "UTF8_CHAR": [x.encode("utf-8") for x in [u"\xc3ll", u"\U0001f604rl"]], + }[unit] + position = np.array(pos, dtype) length = np.array(3, dtype) - expected_value = [b"ell", b"orl"] - - substr_op = string_ops.substr(test_string, position, length) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) - def _testMatrixStrings(self, dtype): - test_string = [[b"ten", b"eleven", b"twelve"], - [b"thirteen", b"fourteen", b"fifteen"], - [b"sixteen", b"seventeen", b"eighteen"]] + @parameterized.parameters( + (np.int32, "BYTE"), + (np.int64, "BYTE"), + (np.int32, "UTF8_CHAR"), + (np.int64, "UTF8_CHAR"), + ) + def testMatrixStrings(self, dtype, unit): + test_string = { + "BYTE": [[b"ten", b"eleven", b"twelve"], + [b"thirteen", b"fourteen", b"fifteen"], + [b"sixteen", b"seventeen", b"eighteen"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n", + u"\xc6\u053c\u025bv\u025bn", + u"tw\u0c1dlv\u025b"]], + [x.encode("utf-8") for x in [u"He\xc3\xc3o", + u"W\U0001f604rld", + u"d\xfcd\xea"]]], + }[unit] position = np.array(1, dtype) length = np.array(4, dtype) - expected_value = [[b"en", b"leve", b"welv"], [b"hirt", b"ourt", b"ifte"], - [b"ixte", b"even", b"ight"]] - - substr_op = string_ops.substr(test_string, position, length) + expected_value 
= { + "BYTE": [[b"en", b"leve", b"welv"], [b"hirt", b"ourt", b"ifte"], + [b"ixte", b"even", b"ight"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n", + u"\u053c\u025bv\u025b", + u"w\u0c1dlv"]], + [x.encode("utf-8") for x in [u"e\xc3\xc3o", + u"\U0001f604rld", + u"\xfcd\xea"]]], + }[unit] + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) - # Negative position - test_string = [[b"ten", b"eleven", b"twelve"], - [b"thirteen", b"fourteen", b"fifteen"], - [b"sixteen", b"seventeen", b"eighteen"]] - position = np.array(-2, dtype) + position = np.array(-3, dtype) length = np.array(2, dtype) - expected_value = [[b"en", b"en", b"ve"], [b"en", b"en", b"en"], - [b"en", b"en", b"en"]] - - substr_op = string_ops.substr(test_string, position, length) + expected_value = { + "BYTE": [[b"te", b"ve", b"lv"], [b"ee", b"ee", b"ee"], + [b"ee", b"ee", b"ee"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227", + u"v\u025b", u"lv"]], + [x.encode("utf-8") for x in [u"\xc3\xc3", u"rl", + u"\xfcd"]]], + }[unit] + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) - def _testElementWisePosLen(self, dtype): - test_string = [[b"ten", b"eleven", b"twelve"], - [b"thirteen", b"fourteen", b"fifteen"], - [b"sixteen", b"seventeen", b"eighteen"]] + @parameterized.parameters( + (np.int32, "BYTE"), + (np.int64, "BYTE"), + (np.int32, "UTF8_CHAR"), + (np.int64, "UTF8_CHAR"), + ) + def testElementWisePosLen(self, dtype, unit): + test_string = { + "BYTE": [[b"ten", b"eleven", b"twelve"], + [b"thirteen", b"fourteen", b"fifteen"], + [b"sixteen", b"seventeen", b"eighteen"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n", + u"\xc6\u053c\u025bv\u025bn", + u"tw\u0c1dlv\u025b"]], + [x.encode("utf-8") for x in 
[u"He\xc3\xc3o", + u"W\U0001f604rld", + u"d\xfcd\xea"]], + [x.encode("utf-8") for x in [u"sixt\xea\xean", + u"se\U00010299enteen", + u"ei\U0001e920h\x86een"]]], + }[unit] position = np.array([[1, -4, 3], [1, 2, -4], [-5, 2, 3]], dtype) length = np.array([[2, 2, 4], [4, 3, 2], [5, 5, 5]], dtype) - expected_value = [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"], - [b"xteen", b"vente", b"hteen"]] - - substr_op = string_ops.substr(test_string, position, length) + expected_value = { + "BYTE": [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"], + [b"xteen", b"vente", b"hteen"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n", + u"\u025bv", + u"lv\u025b"]], + [x.encode("utf-8") for x in [u"e\xc3\xc3o", + u"rld", + u"d\xfc"]], + [x.encode("utf-8") for x in [u"xt\xea\xean", + u"\U00010299ente", + u"h\x86een"]]], + }[unit] + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) - def _testBroadcast(self, dtype): + @parameterized.parameters( + (np.int32, "BYTE"), + (np.int64, "BYTE"), + (np.int32, "UTF8_CHAR"), + (np.int64, "UTF8_CHAR"), + ) + def testBroadcast(self, dtype, unit): # Broadcast pos/len onto input string - test_string = [[b"ten", b"eleven", b"twelve"], - [b"thirteen", b"fourteen", b"fifteen"], - [b"sixteen", b"seventeen", b"eighteen"], - [b"nineteen", b"twenty", b"twentyone"]] + test_string = { + "BYTE": [[b"ten", b"eleven", b"twelve"], + [b"thirteen", b"fourteen", b"fifteen"], + [b"sixteen", b"seventeen", b"eighteen"], + [b"nineteen", b"twenty", b"twentyone"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n", + u"\xc6\u053c\u025bv\u025bn", + u"tw\u0c1dlv\u025b"]], + [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean", + u"f\U0001f604urt\xea\xean", + u"f\xcd\ua09ctee\ua0e4"]], + [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean", + u"se\U00010299enteen", + u"ei\U0001e920h\x86een"]], + [x.encode("utf-8") for x in 
[u"nineteen", + u"twenty", + u"twentyone"]]], + }[unit] position = np.array([1, -4, 3], dtype) length = np.array([1, 2, 3], dtype) - expected_value = [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"], - [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]] - substr_op = string_ops.substr(test_string, position, length) + expected_value = { + "BYTE": [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"], + [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227", + u"\u025bv", u"lv\u025b"]], + [x.encode("utf-8") for x in [u"h", u"t\xea", u"tee"]], + [x.encode("utf-8") for x in [u"\xcd", u"te", u"h\x86e"]], + [x.encode("utf-8") for x in [u"i", u"en", u"nty"]]], + }[unit] + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) # Broadcast input string onto pos/len - test_string = [b"thirteen", b"fourteen", b"fifteen"] + test_string = { + "BYTE": [b"thirteen", b"fourteen", b"fifteen"], + "UTF8_CHAR": [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean", + u"f\U0001f604urt\xea\xean", + u"f\xcd\ua09ctee\ua0e4"]], + }[unit] position = np.array([[1, -2, 3], [-3, 2, 1], [5, 5, -5]], dtype) length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype) - expected_value = [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"], - [b"ee", b"ee", b"ft"]] - substr_op = string_ops.substr(test_string, position, length) + expected_value = { + "BYTE": [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"], + [b"ee", b"ee", b"ft"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"h\xcdr", u"\xean", u"t"]], + [x.encode("utf-8") for x in [u"\xea", u"ur", + u"\xcd\ua09ct"]], + [x.encode("utf-8") for x in [u"\xea\xea", u"\xea\xea", + u"\ua09ct"]]], + }[unit] + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) # Test 1D broadcast - test_string = 
b"thirteen" - position = np.array([1, -5, 7], dtype) + test_string = { + "BYTE": b"thirteen", + "UTF8_CHAR": u"th\xcdrt\xea\xean".encode("utf-8"), + }[unit] + position = np.array([1, -4, 7], dtype) length = np.array([3, 2, 1], dtype) - expected_value = [b"hir", b"rt", b"n"] - substr_op = string_ops.substr(test_string, position, length) + expected_value = { + "BYTE": [b"hir", b"te", b"n"], + "UTF8_CHAR": [x.encode("utf-8") for x in [u"h\xcdr", u"t\xea", u"n"]], + }[unit] + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): substr = substr_op.eval() self.assertAllEqual(substr, expected_value) - def _testBadBroadcast(self, dtype): + @parameterized.parameters( + (np.int32, "BYTE"), + (np.int64, "BYTE"), + (np.int32, "UTF8_CHAR"), + (np.int64, "UTF8_CHAR"), + ) + def testBadBroadcast(self, dtype, unit): test_string = [[b"ten", b"eleven", b"twelve"], [b"thirteen", b"fourteen", b"fifteen"], [b"sixteen", b"seventeen", b"eighteen"]] position = np.array([1, 2, -3, 4], dtype) length = np.array([1, 2, 3, 4], dtype) with self.assertRaises(ValueError): - substr_op = string_ops.substr(test_string, position, length) - - def _testOutOfRangeError(self, dtype): + string_ops.substr(test_string, position, length, unit=unit) + + @parameterized.parameters( + (np.int32, 6, "BYTE"), + (np.int64, 6, "BYTE"), + (np.int32, -6, "BYTE"), + (np.int64, -6, "BYTE"), + (np.int32, 6, "UTF8_CHAR"), + (np.int64, 6, "UTF8_CHAR"), + (np.int32, -6, "UTF8_CHAR"), + (np.int64, -6, "UTF8_CHAR"), + ) + def testOutOfRangeError_Scalar(self, dtype, pos, unit): # Scalar/Scalar - test_string = b"Hello" - position = np.array(7, dtype) - length = np.array(3, dtype) - substr_op = string_ops.substr(test_string, position, length) - with self.cached_session(): - with self.assertRaises(errors_impl.InvalidArgumentError): - substr = substr_op.eval() - - # Scalar/Scalar (with negative) - test_string = b"Hello" - position = np.array(-7, dtype) + test_string = { + "BYTE": 
b"Hello", + "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"), + }[unit] + position = np.array(pos, dtype) length = np.array(3, dtype) - substr_op = string_ops.substr(test_string, position, length) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): - substr = substr_op.eval() - + substr_op.eval() + + @parameterized.parameters( + (np.int32, 4, "BYTE"), + (np.int64, 4, "BYTE"), + (np.int32, -4, "BYTE"), + (np.int64, -4, "BYTE"), + (np.int32, 4, "UTF8_CHAR"), + (np.int64, 4, "UTF8_CHAR"), + (np.int32, -4, "UTF8_CHAR"), + (np.int64, -4, "UTF8_CHAR"), + ) + def testOutOfRangeError_VectorScalar(self, dtype, pos, unit): # Vector/Scalar - test_string = [b"good", b"good", b"bad", b"good"] - position = np.array(4, dtype) - length = np.array(1, dtype) - substr_op = string_ops.substr(test_string, position, length) - with self.cached_session(): - with self.assertRaises(errors_impl.InvalidArgumentError): - substr = substr_op.eval() - - # Vector/Scalar (with negative) - test_string = [b"good", b"good", b"bad", b"good"] - position = np.array(-4, dtype) + test_string = { + "BYTE": [b"good", b"good", b"bad", b"good"], + "UTF8_CHAR": [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"b\xc3d", + u"g\xc3\xc3d"]], + }[unit] + position = np.array(pos, dtype) length = np.array(1, dtype) - substr_op = string_ops.substr(test_string, position, length) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): - substr = substr_op.eval() - + substr_op.eval() + + @parameterized.parameters( + (np.int32, "BYTE"), + (np.int64, "BYTE"), + (np.int32, "UTF8_CHAR"), + (np.int64, "UTF8_CHAR"), + ) + def testOutOfRangeError_MatrixMatrix(self, dtype, unit): # Matrix/Matrix - test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"], - [b"good", b"good", b"good"]] + test_string = { + 
"BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"], + [b"good", b"good", b"good"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d", + u"g\xc3\xc3d"]], + [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d", + u"b\xc3d"]], + [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d", + u"g\xc3\xc3d"]]], + }[unit] position = np.array([[1, 2, 3], [1, 2, 4], [1, 2, 3]], dtype) length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype) - substr_op = string_ops.substr(test_string, position, length) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): - substr = substr_op.eval() + substr_op.eval() # Matrix/Matrix (with negative) - test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"], - [b"good", b"good", b"good"]] position = np.array([[1, 2, -3], [1, 2, -4], [1, 2, -3]], dtype) length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype) - substr_op = string_ops.substr(test_string, position, length) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): - substr = substr_op.eval() - + substr_op.eval() + + @parameterized.parameters( + (np.int32, "BYTE"), + (np.int64, "BYTE"), + (np.int32, "UTF8_CHAR"), + (np.int64, "UTF8_CHAR"), + ) + def testOutOfRangeError_Broadcast(self, dtype, unit): # Broadcast - test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]] + test_string = { + "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d", + u"g\xc3\xc3d"]], + [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d", + u"b\xc3d"]]], + }[unit] position = np.array([1, 2, 4], dtype) length = np.array([1, 2, 3], dtype) - substr_op = string_ops.substr(test_string, position, length) + substr_op = 
string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): - substr = substr_op.eval() + substr_op.eval() # Broadcast (with negative) - test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]] position = np.array([-1, -2, -4], dtype) length = np.array([1, 2, 3], dtype) - substr_op = string_ops.substr(test_string, position, length) + substr_op = string_ops.substr(test_string, position, length, unit=unit) with self.cached_session(): with self.assertRaises(errors_impl.InvalidArgumentError): - substr = substr_op.eval() - - def _testMismatchPosLenShapes(self, dtype): - test_string = [[b"ten", b"eleven", b"twelve"], - [b"thirteen", b"fourteen", b"fifteen"], - [b"sixteen", b"seventeen", b"eighteen"]] + substr_op.eval() + + @parameterized.parameters( + (np.int32, "BYTE"), + (np.int64, "BYTE"), + (np.int32, "UTF8_CHAR"), + (np.int64, "UTF8_CHAR"), + ) + def testMismatchPosLenShapes(self, dtype, unit): + test_string = { + "BYTE": [[b"ten", b"eleven", b"twelve"], + [b"thirteen", b"fourteen", b"fifteen"], + [b"sixteen", b"seventeen", b"eighteen"]], + "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n", + u"\xc6\u053c\u025bv\u025bn", + u"tw\u0c1dlv\u025b"]], + [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean", + u"f\U0001f604urt\xea\xean", + u"f\xcd\ua09ctee\ua0e4"]], + [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean", + u"se\U00010299enteen", + u"ei\U0001e920h\x86een"]]], + }[unit] position = np.array([[1, 2, 3]], dtype) length = np.array([2, 3, 4], dtype) # Should fail: position/length have different rank with self.assertRaises(ValueError): - substr_op = string_ops.substr(test_string, position, length) + string_ops.substr(test_string, position, length) position = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype) length = np.array([[2, 3, 4]], dtype) # Should fail: position/length have different dimensionality with self.assertRaises(ValueError): - substr_op = 
string_ops.substr(test_string, position, length) - - # Negative position. - test_string = [[b"ten", b"eleven", b"twelve"], - [b"thirteen", b"fourteen", b"fifteen"], - [b"sixteen", b"seventeen", b"eighteen"]] - position = np.array([[-1, -2, -3]], dtype) - length = np.array([1, 2, 3], dtype) - # Should fail: position/length have different rank - with self.assertRaises(ValueError): - substr_op = string_ops.substr(test_string, position, length) - - @parameterized.parameters(np.int32, np.int64) - def testAll(self, dtype): - self._testScalarString(dtype) - self._testVectorStrings(dtype) - self._testMatrixStrings(dtype) - self._testElementWisePosLen(dtype) - self._testBroadcast(dtype) - self._testBadBroadcast(dtype) - self._testOutOfRangeError(dtype) - self._testMismatchPosLenShapes(dtype) + string_ops.substr(test_string, position, length) def testWrongDtype(self): with self.cached_session(): @@ -300,6 +478,11 @@ class SubstrOpTest(test.TestCase, parameterized.TestCase): with self.assertRaises(TypeError): string_ops.substr(b"test", 3, 1.0) + def testInvalidUnit(self): + with self.cached_session(): + with self.assertRaises(ValueError): + string_ops.substr(b"test", 3, 1, unit="UTF8") + if __name__ == "__main__": test.main() |