aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/testing
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-02-25 10:39:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-25 11:15:56 -0800
commitd1aed6505a7703ca9f596d8d415257fd238a62c6 (patch)
treede9af9bfdd74b9dea31b4a959ccb1e622c49561e /tensorflow/contrib/testing
parent9ccc4b6afe4defa748de89eac90ba1062232bb5a (diff)
Add contrib/testing.
Change: 115578243
Diffstat (limited to 'tensorflow/contrib/testing')
-rw-r--r--tensorflow/contrib/testing/BUILD29
-rw-r--r--tensorflow/contrib/testing/__init__.py22
-rw-r--r--tensorflow/contrib/testing/python/framework/test_util.py118
3 files changed, 169 insertions, 0 deletions
diff --git a/tensorflow/contrib/testing/BUILD b/tensorflow/contrib/testing/BUILD
new file mode 100644
index 0000000000..21d0b9610c
--- /dev/null
+++ b/tensorflow/contrib/testing/BUILD
@@ -0,0 +1,29 @@
+# Description:
+# contains parts of TensorFlow that are experimental or unstable and which are not supported.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//tensorflow:__subpackages__"])
+
+py_library(
+ name = "testing_py",
+ srcs = [
+ "__init__.py",
+ "python/framework/test_util.py",
+ ],
+ srcs_version = "PY2AND3",
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/testing/__init__.py b/tensorflow/contrib/testing/__init__.py
new file mode 100644
index 0000000000..6807a7dd82
--- /dev/null
+++ b/tensorflow/contrib/testing/__init__.py
@@ -0,0 +1,22 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Testing utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.testing.python.framework.test_util import *
diff --git a/tensorflow/contrib/testing/python/framework/test_util.py b/tensorflow/contrib/testing/python/framework/test_util.py
new file mode 100644
index 0000000000..d48dcf8c62
--- /dev/null
+++ b/tensorflow/contrib/testing/python/framework/test_util.py
@@ -0,0 +1,118 @@
+"""Test utilities."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import glob
+import os
+import numpy as np
+from tensorflow.core.framework import summary_pb2
+from tensorflow.python.platform import logging
+from tensorflow.python.training import summary_io
+
+
+def assert_summary(expected_tags, expected_simple_values, summary_proto):
+ """Asserts summary contains the specified tags and values.
+
+ Args:
+ expected_tags: All tags in summary.
+ expected_simple_values: Simply values for some tags.
+ summary_proto: Summary to validate.
+
+ Raises:
+ ValueError: if expectations are not met.
+ """
+ actual_tags = set()
+ for value in summary_proto.value:
+ actual_tags.add(value.tag)
+ if value.tag in expected_simple_values:
+ expected = expected_simple_values[value.tag]
+ actual = value.simple_value
+ np.testing.assert_almost_equal(
+ actual, expected, decimal=2, err_msg=value.tag)
+ expected_tags = set(expected_tags)
+ if expected_tags != actual_tags:
+ raise ValueError('Expected tags %s, got %s.' % (expected_tags, actual_tags))
+
+
+def to_summary_proto(summary_str):
+ """Create summary based on latest stats.
+
+ Args:
+ summary_str: Serialized summary.
+ Returns:
+ summary_pb2.Summary.
+ Raises:
+ ValueError: if tensor is not a valid summary tensor.
+ """
+ summary = summary_pb2.Summary()
+ summary.ParseFromString(summary_str)
+ return summary
+
+
+# TODO(ptucker): Move to a non-test package?
+def latest_event_file(base_dir):
+ """Find latest event file in `base_dir`.
+
+ Args:
+ base_dir: Base directory in which TF event flies are stored.
+ Returns:
+ File path, or `None` if none exists.
+ """
+ file_paths = glob.glob(os.path.join(base_dir, 'events.*'))
+ return sorted(file_paths)[-1] if file_paths else None
+
+
+def latest_events(base_dir):
+ """Parse events from latest event file in base_dir.
+
+ Args:
+ base_dir: Base directory in which TF event flies are stored.
+ Returns:
+ Iterable of event protos.
+ Raises:
+ ValueError: if no event files exist under base_dir.
+ """
+ file_path = latest_event_file(base_dir)
+ return summary_io.summary_iterator(file_path) if file_path else []
+
+
+def latest_summaries(base_dir):
+ """Parse summary events from latest event file in base_dir.
+
+ Args:
+ base_dir: Base directory in which TF event flies are stored.
+ Returns:
+ List of event protos.
+ Raises:
+ ValueError: if no event files exist under base_dir.
+ """
+ return [e for e in latest_events(base_dir) if e.HasField('summary')]
+
+
+def simple_values_from_events(events, tags):
+ """Parse summaries from events with simple_value.
+
+ Args:
+ events: List of tensorflow.Event protos.
+ tags: List of string event tags corresponding to simple_value summaries.
+ Returns:
+ dict of tag:value.
+ Raises:
+ ValueError: if a summary with a specified tag does not contain simple_value.
+ """
+ step_by_tag = {}
+ value_by_tag = {}
+ for e in events:
+ if e.HasField('summary'):
+ for v in e.summary.value:
+ tag = v.tag
+ if tag in tags:
+ if not v.HasField('simple_value'):
+ raise ValueError('Summary for %s is not a simple_value.' % tag)
+ # The events are mostly sorted in step order, but we explicitly check
+ # just in case.
+ if tag not in step_by_tag or e.step > step_by_tag[tag]:
+ step_by_tag[tag] = e.step
+ value_by_tag[tag] = v.simple_value
+ return value_by_tag
+