aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/parsing_ops_test.py
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2015-12-06 14:53:28 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2015-12-06 14:53:28 -0800
commitf9d3e9d03c69bfac77a2fe1ad80f7c5aa517e0f0 (patch)
tree52302a06eae969c8f4e1d7af6749a85fe0ac4eb1 /tensorflow/python/kernel_tests/parsing_ops_test.py
parent40d0d2904e8e00d3c4bf43fa62130eeebceef147 (diff)
TensorFlow: upstream latest changes to git.
Change 109537918 TensorFlow pip setup: wheel >= 0.26 for python3 pip install Change 109505848 Fix distortion default value to 1.0 in fixed_unigram_candidate_sampler. This means we default to the actual provided unigram distribution, instead of to the uniform (as it is currently). Change 109470494 Bugfix in gradients calculation when the ys rely on each other. Change 109467619 Fix CIFAR-10 model to train on all the training data instead of just 80% of it. Fixes #396. Change 109467557 Replaced checkpoint file with binary GraphDef. Change 109467433 Updates to C++ tutorial section. Change 109465269 TensorFlow: update documentation for tutorials to not assume use of bazel (when possible). Change 109462916 A tutorial for image recognition to coincide with the release of the latest Inception image classification model. Change 109462342 Clear control dependencies in variable_scope.get_variable() when creating ops for the initializer. Add tests of various error conditions. Change 109461981 Various performance improvements in low-level node execution code paths. Speeds up ptb_word_lm on my desktop with a Titan X from 3638 words per second to 3751 words per second (3.1% speedup). Changes include: o Avoided many strcmp operations per node execution and extra touches of cache lines in executor.cc, by making all the various IsMerge, IsSwitch, IsSend, etc. operations instead be based on an internal enum value that is pre-computed at Node construction time, rather than doing string comparisons against node->type_string(). We were doing about 6 such comparisons per executed node. o Removed mutex_lock in executor.cc in ExecutorState::Process. The lock was not needed and the comment about the iterations array being potentially resized is not true (the iterations arrays are created with a fixed size). Checked with yuanbyu to confirm this. o Added new two-argument port::Tracing::ScopedAnnotation constructor that takes two StringPiece arguments, and only concatenates them lazily if tracing is enabled. Also changed the code in platform/tracing.{h,cc} so that the ScopedAnnotation constructor and the TraceMe constructor can be inlined. o In BaseGPUDevice::Compute, used the two-argument ScopedAnnotation constructor to avoid doing StrCat(opkernel->name(), ":", op_kernel->type_string()) on every node execution on a GPU. o Introduced a new TensorReference class that just holds a reference to an underlying TensorBuffer, and requires an explicit Unref(). o Changed the EventMgr interface to take a vector of TensorReference objects for EventMgr::ThenDeleteTensors, rather than a vector of Tensor objects. o Used TensorReference in a few places in gpu_util.cc o Minor: switched to using InlinedVectors in a few places to get better cache locality. Change 109456692 Updated the label_image example to use the latest Inception model Change 109456545 Provides classify_image which performs image recognition on a 1000 object label set. $ ./classify_image giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca (score = 0.88493) indri, indris, Indri indri, Indri brevicaudatus (score = 0.00878) lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens (score = 0.00317) custard apple (score = 0.00149) earthstar (score = 0.00127) Change 109455002 TensorFlow: make the helper libraries for various models available in the pip package so that when users type: python translate.py ... the absolute import works. This change is supposed to help make our tutorials run without the *need* to use bazel. Change 109450041 TensorFlow: remove cifar and convolutional binary copies from pip install. Adds embedding and some other models to the list. Change 109448520 Move the description of a failing invariant from a comment into the dcheck-fail message text. Change 109447577 TensorBoard has release tagging (tensorboard/TAG) Also track TensorBoard changes (tensorboard/CHANGES) Change 109444161 Added ParseSingleSequenceExample + python wrappers + unit tests. Change 109440864 Update all the TensorFlow Dockerfiles, and simplify GPU containers. This change updates all four of our Dockerfiles to match the targets discussed in https://github.com/tensorflow/tensorflow/issues/149. The most notable change here is moving the GPU images to use the NVidia containers which include cudnn and other build-time dependencies, dramatically simplifying both the build and run steps. A description of which tags exist and get pushed where will be in a follow-up. Change 109432591 Some pylint and pydoc changes in saver. Change 109430127 Remove unused hydrogen components Change 109419354 The RNN api, although moved into python/ops/, remains undocumented. It may still change at any time. Base CL: 109538006
Diffstat (limited to 'tensorflow/python/kernel_tests/parsing_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/parsing_ops_test.py254
1 files changed, 246 insertions, 8 deletions
diff --git a/tensorflow/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/kernel_tests/parsing_ops_test.py
index 22ab1716ca..490908120f 100644
--- a/tensorflow/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/kernel_tests/parsing_ops_test.py
@@ -95,11 +95,11 @@ class ParseExampleTest(tf.test.TestCase):
for (k, s) in zip(dense_keys, dense_shapes):
self.assertEqual(
tuple(out[k].get_shape().as_list()), (batch_size,) + s)
- for k in sparse_keys:
- self.assertEqual(
- tuple(out[k].indices.get_shape().as_list()), (None, 2))
- self.assertEqual(tuple(out[k].values.get_shape().as_list()), (None,))
- self.assertEqual(tuple(out[k].shape.get_shape().as_list()), (2,))
+ for k in sparse_keys:
+ self.assertEqual(
+ tuple(out[k].indices.get_shape().as_list()), (None, 2))
+ self.assertEqual(tuple(out[k].values.get_shape().as_list()), (None,))
+ self.assertEqual(tuple(out[k].shape.get_shape().as_list()), (2,))
# Check values
result = flatten_values_tensors_or_sparse(out.values()) # flatten values
@@ -193,7 +193,7 @@ class ParseExampleTest(tf.test.TestCase):
"dense_types": [tf.float32],
"dense_shapes": dense_shapes,
},
- expected_err_re="Name: failing, Key: a. Number of float values")
+ expected_err_re="Name: failing, Key: a, Index: 1. Number of float val")
def testDenseDefaultNoShapeShouldFail(self):
original = [
@@ -211,7 +211,7 @@ class ParseExampleTest(tf.test.TestCase):
"dense_keys": ["a"],
"dense_types": [tf.float32],
},
- expected_err_re="Name: failing, Key: a. Number of float values")
+ expected_err_re="Name: failing, Key: a, Index: 0. Number of float val")
def testDenseDefaultNoShapeOk(self):
original = [
@@ -468,7 +468,7 @@ class ParseSingleExampleTest(tf.test.TestCase):
self._test(
{
- "names": "in1",
+ "names": tf.convert_to_tensor("in1"),
"serialized": tf.convert_to_tensor(serialized),
"dense_defaults": dense_defaults,
"dense_types": dense_types,
@@ -497,6 +497,244 @@ class ParseSequenceExampleTest(tf.test.TestCase):
}))
value.SerializeToString() # Smoke test
+ def _test(self, kwargs, expected_context_values=None,
+ expected_feat_list_values=None, expected_err_re=None):
+ expected_context_values = expected_context_values or {}
+ expected_feat_list_values = expected_feat_list_values or {}
+ with self.test_session() as sess:
+ # Pull out some keys to check shape inference
+ context_dense_keys = (
+ kwargs["context_dense_keys"]
+ if "context_dense_keys" in kwargs else [])
+ context_sparse_keys = (
+ kwargs["context_sparse_keys"]
+ if "context_sparse_keys" in kwargs else [])
+ context_dense_shapes = (
+ kwargs["context_dense_shapes"]
+ if "context_dense_shapes" in kwargs else [])
+ feature_list_dense_keys = (
+ kwargs["feature_list_dense_keys"]
+ if "feature_list_dense_keys" in kwargs else [])
+ feature_list_dense_shapes = (
+ kwargs["feature_list_dense_shapes"]
+ if "feature_list_dense_shapes" in kwargs else [])
+
+ # Returns dict w/ Tensors and SparseTensors
+ (context_out, feat_list_out) = tf.parse_single_sequence_example(**kwargs)
+
+ # Check shapes; if serialized is a Tensor we need its size to
+ # properly check.
+ if context_dense_shapes:
+ self.assertEqual(len(context_dense_keys), len(context_dense_shapes))
+ for (k, s) in zip(context_dense_keys, context_dense_shapes):
+ self.assertEqual(
+ tuple(context_out[k].get_shape().as_list()), s)
+ for k in context_sparse_keys:
+ self.assertEqual(
+ tuple(context_out[k].indices.get_shape().as_list()), (None, 1))
+ self.assertEqual(
+ tuple(context_out[k].values.get_shape().as_list()), (None,))
+ self.assertEqual(
+ tuple(context_out[k].shape.get_shape().as_list()), (1,))
+ if feature_list_dense_shapes:
+ self.assertEqual(
+ len(feature_list_dense_keys), len(feature_list_dense_shapes))
+ for (k, s) in zip(feature_list_dense_keys, feature_list_dense_shapes):
+ self.assertEqual(
+ tuple(feat_list_out[k].get_shape().as_list()), (None,) + s)
+
+ # Check values
+ context_result = flatten_values_tensors_or_sparse(
+ context_out.values()) # flatten values
+ feature_list_result = flatten_values_tensors_or_sparse(
+ feat_list_out.values())
+ if expected_err_re is None:
+ tf_context_result = sess.run(context_result)
+ tf_feat_list_result = sess.run(feature_list_result)
+ _compare_output_to_expected(
+ self, context_out, expected_context_values, tf_context_result)
+ _compare_output_to_expected(
+ self, feat_list_out, expected_feat_list_values, tf_feat_list_result)
+ else:
+ with self.assertRaisesOpError(expected_err_re):
+ sess.run(context_result)
+
+ def testSequenceExampleWithSparseAndDenseContext(self):
+ context_dense_types = [tf.int64, tf.string, tf.float32]
+ context_dense_shapes = [(1, 3), (3, 3), (2,)]
+ context_dense_defaults = {
+ "a": [1, 2, 3],
+ "b": np.random.rand(3, 3).astype(bytes),
+ # Feature "c" must be provided
+ }
+
+ original = sequence_example(context=features(
+ {"c": float_feature([3, 4]),
+ "st_a": float_feature([3.0, 4.0])}))
+
+ serialized = original.SerializeToString()
+
+ expected_st_a = (
+ np.array([[0], [1]], dtype=np.int64), # indices
+ np.array([3.0, 4.0], dtype=np.float32), # values
+ np.array([2], dtype=np.int64)) # shape: num_features = 2
+
+ expected_context_output = {
+ "st_a": expected_st_a,
+ "a": [context_dense_defaults["a"]],
+ "b": context_dense_defaults["b"],
+ "c": np.array([3, 4], dtype=np.float32),
+ }
+
+ self._test(
+ {
+ "debug_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "context_dense_defaults": context_dense_defaults,
+ "context_dense_types": context_dense_types,
+ "context_sparse_keys": ["st_a"],
+ "context_sparse_types": [tf.float32],
+ "context_dense_keys": ["a", "b", "c"],
+ "context_dense_shapes": context_dense_shapes
+ }, expected_context_values=expected_context_output)
+
+ def testSequenceExampleWithMultipleSizeFeatureLists(self):
+ feature_list_dense_keys = ["a", "b", "c", "d"]
+ feature_list_dense_types = [tf.int64, tf.string, tf.float32, tf.float32]
+ feature_list_dense_shapes = [(1, 3), (2, 2), (2,), (5,)]
+
+ original = sequence_example(feature_lists=feature_lists({
+ "a": feature_list([
+ int64_feature([-1, 0, 1]),
+ int64_feature([2, 3, 4]),
+ int64_feature([5, 6, 7]),
+ int64_feature([8, 9, 10]),]),
+ "b": feature_list([
+ bytes_feature(["r00", "r01", "r10", "r11"])]),
+ "c": feature_list([
+ float_feature([3, 4]),
+ float_feature([-1, 2])]),
+ }))
+
+ serialized = original.SerializeToString()
+
+ expected_feature_list_output = {
+ "a": np.array([ # outer dimension is time.
+ [[-1, 0, 1]], # inside are 1x3 matrices
+ [[2, 3, 4]],
+ [[5, 6, 7]],
+ [[8, 9, 10]]], dtype=np.int64),
+ "b": np.array([ # outer dimension is time, inside are 2x2 matrices
+ [["r00", "r01"], ["r10", "r11"]]], dtype=np.str),
+ "c": np.array([ # outer dimension is time, inside are 2-vectors
+ [3, 4],
+ [-1, 2]], dtype=np.float32),
+ "d": np.empty(shape=(0, 5), dtype=np.float32), # empty_allowed_missing
+ }
+
+ self._test(
+ {
+ "debug_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "feature_list_dense_types": feature_list_dense_types,
+ "feature_list_dense_keys": feature_list_dense_keys,
+ "feature_list_dense_shapes": feature_list_dense_shapes,
+ "feature_list_dense_defaults": {"d": None},
+ }, expected_feat_list_values=expected_feature_list_output)
+
+ def testSequenceExampleListWithInconsistentDataFails(self):
+ feature_list_dense_types = [tf.int64]
+ feature_list_dense_shapes = [(2,)]
+
+ original = sequence_example(feature_lists=feature_lists({
+ "a": feature_list([
+ int64_feature([-1, 0]),
+ float_feature([2, 3])])
+ }))
+
+ serialized = original.SerializeToString()
+
+ self._test(
+ {
+ "debug_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "feature_list_dense_types": feature_list_dense_types,
+ "feature_list_dense_keys": ["a"],
+ "feature_list_dense_shapes": feature_list_dense_shapes
+ },
+ expected_err_re=("Feature list: a, Index: 1. Data types don't match. "
+ "Expected type: int64"))
+
+ def testSequenceExampleListWithWrongDataTypeFails(self):
+ feature_list_dense_types = [tf.int64]
+ feature_list_dense_shapes = [(2,)]
+
+ original = sequence_example(feature_lists=feature_lists({
+ "a": feature_list([
+ float_feature([2, 3])])
+ }))
+
+ serialized = original.SerializeToString()
+
+ self._test(
+ {
+ "debug_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "feature_list_dense_types": feature_list_dense_types,
+ "feature_list_dense_keys": ["a"],
+ "feature_list_dense_shapes": feature_list_dense_shapes
+ },
+ expected_err_re=("Feature list: a, Index: 0. Data types don't match. "
+ "Expected type: int64"))
+
+ def testSequenceExampleListWithWrongShapeFails(self):
+ feature_list_dense_types = [tf.int64]
+ feature_list_dense_shapes = [(2,)]
+
+ original = sequence_example(feature_lists=feature_lists({
+ "a": feature_list([
+ int64_feature([2, 3]),
+ int64_feature([2, 3, 4])]),
+ }))
+
+ serialized = original.SerializeToString()
+
+ self._test(
+ {
+ "debug_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "feature_list_dense_types": feature_list_dense_types,
+ "feature_list_dense_keys": ["a"],
+ "feature_list_dense_shapes": feature_list_dense_shapes
+ },
+ expected_err_re=(r"Name: in1, Key: a, Index: 1. "
+ r"Number of int64 values != expected. "
+ r"values size: 3 but output shape: \[2\]"))
+
+ def testSequenceExampleWithMissingFeatureListFails(self):
+ feature_list_dense_types = [tf.int64]
+ feature_list_dense_shapes = [(2,)]
+
+ original = sequence_example(feature_lists=feature_lists({}))
+
+ serialized = original.SerializeToString()
+
+ # Test fails because we didn't add:
+ # feature_list_dense_defaults = {"a": None}
+ self._test(
+ {
+ "debug_name": "in1",
+ "serialized": tf.convert_to_tensor(serialized),
+ "feature_list_dense_types": feature_list_dense_types,
+ "feature_list_dense_keys": ["a"],
+ "feature_list_dense_shapes": feature_list_dense_shapes
+ },
+ expected_err_re=(
+ "Name: in1, Feature list 'a' is required but could not be found. "
+ "Did you mean to include it in "
+ "feature_list_dense_missing_assumed_empty or "
+ "feature_list_dense_defaults?"))
+
if __name__ == "__main__":
tf.test.main()