about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/kernel_tests
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-10-04 11:30:52 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-10-04 11:34:46 -0700
commit: 700c3325311e16be9bb4856cbf944d1871ff35c1 (patch)
tree: 9ae88328889950abaa951a628de7212caec8c026 /tensorflow/python/kernel_tests
parent: c8d5054e8c12800f0c3db0e51f3d5902e04eaa37 (diff)
Add "encoding" attribute to string substr op, which controls how each "character" is treated:
* BYTE: Position & length refer to bytes in the string. (Default)
* UTF8: The string is interpreted as UTF-8 encoded Unicode code points, and position & length are treated relative to them.

RELNOTES: Add option to get substring using Unicode characters
PiperOrigin-RevId: 215773373
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r-- tensorflow/python/kernel_tests/substr_op_test.py 503
1 files changed, 343 insertions, 160 deletions
diff --git a/tensorflow/python/kernel_tests/substr_op_test.py b/tensorflow/python/kernel_tests/substr_op_test.py
index cd3fe14883..37aa624b07 100644
--- a/tensorflow/python/kernel_tests/substr_op_test.py
+++ b/tensorflow/python/kernel_tests/substr_op_test.py
@@ -28,270 +28,448 @@ from tensorflow.python.platform import test
class SubstrOpTest(test.TestCase, parameterized.TestCase):
- def _testScalarString(self, dtype):
- test_string = b"Hello"
- position = np.array(1, dtype)
+ @parameterized.parameters(
+ (np.int32, 1, "BYTE"),
+ (np.int64, 1, "BYTE"),
+ (np.int32, -4, "BYTE"),
+ (np.int64, -4, "BYTE"),
+ (np.int32, 1, "UTF8_CHAR"),
+ (np.int64, 1, "UTF8_CHAR"),
+ (np.int32, -4, "UTF8_CHAR"),
+ (np.int64, -4, "UTF8_CHAR"),
+ )
+ def testScalarString(self, dtype, pos, unit):
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"He\xc3\xc3\U0001f604".encode("utf-8"),
+ }[unit]
+ expected_value = {
+ "BYTE": b"ell",
+ "UTF8_CHAR": u"e\xc3\xc3".encode("utf-8"),
+ }[unit]
+ position = np.array(pos, dtype)
length = np.array(3, dtype)
- expected_value = b"ell"
-
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- # Negative position.
- test_string = b"Hello"
- position = np.array(-4, dtype)
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testScalarString_EdgeCases(self, dtype, unit):
+ # Empty string
+ test_string = {
+ "BYTE": b"",
+ "UTF8_CHAR": u"".encode("utf-8"),
+ }[unit]
+ expected_value = b""
+ position = np.array(0, dtype)
length = np.array(3, dtype)
- expected_value = b"ell"
-
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- # Position is equal to the length of string.
- test_string = b""
+ # Full string
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
position = np.array(0, dtype)
- length = np.array(2, dtype)
- expected_value = b""
-
- substr_op = string_ops.substr(test_string, position, length)
+ length = np.array(5, dtype)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
- self.assertAllEqual(substr, expected_value)
-
- # Negative position magnitude is equal to the length of string.
- test_string = b"yo"
- position = np.array(-2, dtype)
- length = np.array(1, dtype)
- expected_value = b"y"
-
- substr_op = string_ops.substr(test_string, position, length)
+ self.assertAllEqual(substr, test_string)
+
+ # Full string (Negative)
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
+ position = np.array(-5, dtype)
+ length = np.array(5, dtype)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
- self.assertAllEqual(substr, expected_value)
-
- def _testVectorStrings(self, dtype):
- test_string = [b"Hello", b"World"]
- position = np.array(1, dtype)
- length = np.array(3, dtype)
- expected_value = [b"ell", b"orl"]
-
- substr_op = string_ops.substr(test_string, position, length)
+ self.assertAllEqual(substr, test_string)
+
+ # Length is larger in magnitude than a negative position
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
+ expected_string = {
+ "BYTE": b"ello",
+ "UTF8_CHAR": u"\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
+ position = np.array(-4, dtype)
+ length = np.array(5, dtype)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
- self.assertAllEqual(substr, expected_value)
-
- # Negative position.
- test_string = [b"Hello", b"World"]
- position = np.array(-4, dtype)
+ self.assertAllEqual(substr, expected_string)
+
+ @parameterized.parameters(
+ (np.int32, 1, "BYTE"),
+ (np.int64, 1, "BYTE"),
+ (np.int32, -4, "BYTE"),
+ (np.int64, -4, "BYTE"),
+ (np.int32, 1, "UTF8_CHAR"),
+ (np.int64, 1, "UTF8_CHAR"),
+ (np.int32, -4, "UTF8_CHAR"),
+ (np.int64, -4, "UTF8_CHAR"),
+ )
+ def testVectorStrings(self, dtype, pos, unit):
+ test_string = {
+ "BYTE": [b"Hello", b"World"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"H\xc3llo",
+ u"W\U0001f604rld"]],
+ }[unit]
+ expected_value = {
+ "BYTE": [b"ell", b"orl"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"\xc3ll", u"\U0001f604rl"]],
+ }[unit]
+ position = np.array(pos, dtype)
length = np.array(3, dtype)
- expected_value = [b"ell", b"orl"]
-
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- def _testMatrixStrings(self, dtype):
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testMatrixStrings(self, dtype, unit):
+ test_string = {
+ "BYTE": [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
+ u"\xc6\u053c\u025bv\u025bn",
+ u"tw\u0c1dlv\u025b"]],
+ [x.encode("utf-8") for x in [u"He\xc3\xc3o",
+ u"W\U0001f604rld",
+ u"d\xfcd\xea"]]],
+ }[unit]
position = np.array(1, dtype)
length = np.array(4, dtype)
- expected_value = [[b"en", b"leve", b"welv"], [b"hirt", b"ourt", b"ifte"],
- [b"ixte", b"even", b"ight"]]
-
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"en", b"leve", b"welv"], [b"hirt", b"ourt", b"ifte"],
+ [b"ixte", b"even", b"ight"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n",
+ u"\u053c\u025bv\u025b",
+ u"w\u0c1dlv"]],
+ [x.encode("utf-8") for x in [u"e\xc3\xc3o",
+ u"\U0001f604rld",
+ u"\xfcd\xea"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- # Negative position
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
- position = np.array(-2, dtype)
+ position = np.array(-3, dtype)
length = np.array(2, dtype)
- expected_value = [[b"en", b"en", b"ve"], [b"en", b"en", b"en"],
- [b"en", b"en", b"en"]]
-
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"te", b"ve", b"lv"], [b"ee", b"ee", b"ee"],
+ [b"ee", b"ee", b"ee"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227",
+ u"v\u025b", u"lv"]],
+ [x.encode("utf-8") for x in [u"\xc3\xc3", u"rl",
+ u"\xfcd"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- def _testElementWisePosLen(self, dtype):
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testElementWisePosLen(self, dtype, unit):
+ test_string = {
+ "BYTE": [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
+ u"\xc6\u053c\u025bv\u025bn",
+ u"tw\u0c1dlv\u025b"]],
+ [x.encode("utf-8") for x in [u"He\xc3\xc3o",
+ u"W\U0001f604rld",
+ u"d\xfcd\xea"]],
+ [x.encode("utf-8") for x in [u"sixt\xea\xean",
+ u"se\U00010299enteen",
+ u"ei\U0001e920h\x86een"]]],
+ }[unit]
position = np.array([[1, -4, 3], [1, 2, -4], [-5, 2, 3]], dtype)
length = np.array([[2, 2, 4], [4, 3, 2], [5, 5, 5]], dtype)
- expected_value = [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"],
- [b"xteen", b"vente", b"hteen"]]
-
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"],
+ [b"xteen", b"vente", b"hteen"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n",
+ u"\u025bv",
+ u"lv\u025b"]],
+ [x.encode("utf-8") for x in [u"e\xc3\xc3o",
+ u"rld",
+ u"d\xfc"]],
+ [x.encode("utf-8") for x in [u"xt\xea\xean",
+ u"\U00010299ente",
+ u"h\x86een"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- def _testBroadcast(self, dtype):
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testBroadcast(self, dtype, unit):
# Broadcast pos/len onto input string
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"],
- [b"nineteen", b"twenty", b"twentyone"]]
+ test_string = {
+ "BYTE": [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"],
+ [b"nineteen", b"twenty", b"twentyone"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
+ u"\xc6\u053c\u025bv\u025bn",
+ u"tw\u0c1dlv\u025b"]],
+ [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
+ u"f\U0001f604urt\xea\xean",
+ u"f\xcd\ua09ctee\ua0e4"]],
+ [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean",
+ u"se\U00010299enteen",
+ u"ei\U0001e920h\x86een"]],
+ [x.encode("utf-8") for x in [u"nineteen",
+ u"twenty",
+ u"twentyone"]]],
+ }[unit]
position = np.array([1, -4, 3], dtype)
length = np.array([1, 2, 3], dtype)
- expected_value = [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"],
- [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]]
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"],
+ [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227",
+ u"\u025bv", u"lv\u025b"]],
+ [x.encode("utf-8") for x in [u"h", u"t\xea", u"tee"]],
+ [x.encode("utf-8") for x in [u"\xcd", u"te", u"h\x86e"]],
+ [x.encode("utf-8") for x in [u"i", u"en", u"nty"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
# Broadcast input string onto pos/len
- test_string = [b"thirteen", b"fourteen", b"fifteen"]
+ test_string = {
+ "BYTE": [b"thirteen", b"fourteen", b"fifteen"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
+ u"f\U0001f604urt\xea\xean",
+ u"f\xcd\ua09ctee\ua0e4"]],
+ }[unit]
position = np.array([[1, -2, 3], [-3, 2, 1], [5, 5, -5]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
- expected_value = [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"],
- [b"ee", b"ee", b"ft"]]
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"],
+ [b"ee", b"ee", b"ft"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"h\xcdr", u"\xean", u"t"]],
+ [x.encode("utf-8") for x in [u"\xea", u"ur",
+ u"\xcd\ua09ct"]],
+ [x.encode("utf-8") for x in [u"\xea\xea", u"\xea\xea",
+ u"\ua09ct"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
# Test 1D broadcast
- test_string = b"thirteen"
- position = np.array([1, -5, 7], dtype)
+ test_string = {
+ "BYTE": b"thirteen",
+ "UTF8_CHAR": u"th\xcdrt\xea\xean".encode("utf-8"),
+ }[unit]
+ position = np.array([1, -4, 7], dtype)
length = np.array([3, 2, 1], dtype)
- expected_value = [b"hir", b"rt", b"n"]
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [b"hir", b"te", b"n"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"h\xcdr", u"t\xea", u"n"]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- def _testBadBroadcast(self, dtype):
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testBadBroadcast(self, dtype, unit):
test_string = [[b"ten", b"eleven", b"twelve"],
[b"thirteen", b"fourteen", b"fifteen"],
[b"sixteen", b"seventeen", b"eighteen"]]
position = np.array([1, 2, -3, 4], dtype)
length = np.array([1, 2, 3, 4], dtype)
with self.assertRaises(ValueError):
- substr_op = string_ops.substr(test_string, position, length)
-
- def _testOutOfRangeError(self, dtype):
+ string_ops.substr(test_string, position, length, unit=unit)
+
+ @parameterized.parameters(
+ (np.int32, 6, "BYTE"),
+ (np.int64, 6, "BYTE"),
+ (np.int32, -6, "BYTE"),
+ (np.int64, -6, "BYTE"),
+ (np.int32, 6, "UTF8_CHAR"),
+ (np.int64, 6, "UTF8_CHAR"),
+ (np.int32, -6, "UTF8_CHAR"),
+ (np.int64, -6, "UTF8_CHAR"),
+ )
+ def testOutOfRangeError_Scalar(self, dtype, pos, unit):
# Scalar/Scalar
- test_string = b"Hello"
- position = np.array(7, dtype)
- length = np.array(3, dtype)
- substr_op = string_ops.substr(test_string, position, length)
- with self.cached_session():
- with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
- # Scalar/Scalar (with negative)
- test_string = b"Hello"
- position = np.array(-7, dtype)
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
+ position = np.array(pos, dtype)
length = np.array(3, dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
+ substr_op.eval()
+
+ @parameterized.parameters(
+ (np.int32, 4, "BYTE"),
+ (np.int64, 4, "BYTE"),
+ (np.int32, -4, "BYTE"),
+ (np.int64, -4, "BYTE"),
+ (np.int32, 4, "UTF8_CHAR"),
+ (np.int64, 4, "UTF8_CHAR"),
+ (np.int32, -4, "UTF8_CHAR"),
+ (np.int64, -4, "UTF8_CHAR"),
+ )
+ def testOutOfRangeError_VectorScalar(self, dtype, pos, unit):
# Vector/Scalar
- test_string = [b"good", b"good", b"bad", b"good"]
- position = np.array(4, dtype)
- length = np.array(1, dtype)
- substr_op = string_ops.substr(test_string, position, length)
- with self.cached_session():
- with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
- # Vector/Scalar (with negative)
- test_string = [b"good", b"good", b"bad", b"good"]
- position = np.array(-4, dtype)
+ test_string = {
+ "BYTE": [b"good", b"good", b"bad", b"good"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"b\xc3d",
+ u"g\xc3\xc3d"]],
+ }[unit]
+ position = np.array(pos, dtype)
length = np.array(1, dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
+ substr_op.eval()
+
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testOutOfRangeError_MatrixMatrix(self, dtype, unit):
# Matrix/Matrix
- test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
- [b"good", b"good", b"good"]]
+ test_string = {
+ "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
+ [b"good", b"good", b"good"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"g\xc3\xc3d"]],
+ [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"b\xc3d"]],
+ [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"g\xc3\xc3d"]]],
+ }[unit]
position = np.array([[1, 2, 3], [1, 2, 4], [1, 2, 3]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
+ substr_op.eval()
# Matrix/Matrix (with negative)
- test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
- [b"good", b"good", b"good"]]
position = np.array([[1, 2, -3], [1, 2, -4], [1, 2, -3]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
+ substr_op.eval()
+
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testOutOfRangeError_Broadcast(self, dtype, unit):
# Broadcast
- test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]]
+ test_string = {
+ "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"g\xc3\xc3d"]],
+ [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"b\xc3d"]]],
+ }[unit]
position = np.array([1, 2, 4], dtype)
length = np.array([1, 2, 3], dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
+ substr_op.eval()
# Broadcast (with negative)
- test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]]
position = np.array([-1, -2, -4], dtype)
length = np.array([1, 2, 3], dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
- def _testMismatchPosLenShapes(self, dtype):
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
+ substr_op.eval()
+
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testMismatchPosLenShapes(self, dtype, unit):
+ test_string = {
+ "BYTE": [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
+ u"\xc6\u053c\u025bv\u025bn",
+ u"tw\u0c1dlv\u025b"]],
+ [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
+ u"f\U0001f604urt\xea\xean",
+ u"f\xcd\ua09ctee\ua0e4"]],
+ [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean",
+ u"se\U00010299enteen",
+ u"ei\U0001e920h\x86een"]]],
+ }[unit]
position = np.array([[1, 2, 3]], dtype)
length = np.array([2, 3, 4], dtype)
# Should fail: position/length have different rank
with self.assertRaises(ValueError):
- substr_op = string_ops.substr(test_string, position, length)
+ string_ops.substr(test_string, position, length)
position = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype)
length = np.array([[2, 3, 4]], dtype)
# Should fail: position/length have different dimensionality
with self.assertRaises(ValueError):
- substr_op = string_ops.substr(test_string, position, length)
-
- # Negative position.
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
- position = np.array([[-1, -2, -3]], dtype)
- length = np.array([1, 2, 3], dtype)
- # Should fail: position/length have different rank
- with self.assertRaises(ValueError):
- substr_op = string_ops.substr(test_string, position, length)
-
- @parameterized.parameters(np.int32, np.int64)
- def testAll(self, dtype):
- self._testScalarString(dtype)
- self._testVectorStrings(dtype)
- self._testMatrixStrings(dtype)
- self._testElementWisePosLen(dtype)
- self._testBroadcast(dtype)
- self._testBadBroadcast(dtype)
- self._testOutOfRangeError(dtype)
- self._testMismatchPosLenShapes(dtype)
+ string_ops.substr(test_string, position, length)
def testWrongDtype(self):
with self.cached_session():
@@ -300,6 +478,11 @@ class SubstrOpTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(TypeError):
string_ops.substr(b"test", 3, 1.0)
+ def testInvalidUnit(self):
+ with self.cached_session():
+ with self.assertRaises(ValueError):
+ string_ops.substr(b"test", 3, 1, unit="UTF8")
+
if __name__ == "__main__":
test.main()