diff options
author | A. Unique TensorFlower <nobody@tensorflow.org> | 2016-02-25 10:39:17 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-02-25 11:15:56 -0800 |
commit | d1aed6505a7703ca9f596d8d415257fd238a62c6 (patch) | |
tree | de9af9bfdd74b9dea31b4a959ccb1e622c49561e /tensorflow/contrib/testing | |
parent | 9ccc4b6afe4defa748de89eac90ba1062232bb5a (diff) |
Add contrib/testing.
Change: 115578243
Diffstat (limited to 'tensorflow/contrib/testing')
-rw-r--r-- | tensorflow/contrib/testing/BUILD | 29 | ||||
-rw-r--r-- | tensorflow/contrib/testing/__init__.py | 22 | ||||
-rw-r--r-- | tensorflow/contrib/testing/python/framework/test_util.py | 118 |
3 files changed, 169 insertions, 0 deletions
diff --git a/tensorflow/contrib/testing/BUILD b/tensorflow/contrib/testing/BUILD new file mode 100644 index 0000000000..21d0b9610c --- /dev/null +++ b/tensorflow/contrib/testing/BUILD @@ -0,0 +1,29 @@ +# Description: +# contains parts of TensorFlow that are experimental or unstable and which are not supported. + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +py_library( + name = "testing_py", + srcs = [ + "__init__.py", + "python/framework/test_util.py", + ], + srcs_version = "PY2AND3", +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/testing/__init__.py b/tensorflow/contrib/testing/__init__.py new file mode 100644 index 0000000000..6807a7dd82 --- /dev/null +++ b/tensorflow/contrib/testing/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2015 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Testing utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.testing.python.framework.test_util import * diff --git a/tensorflow/contrib/testing/python/framework/test_util.py b/tensorflow/contrib/testing/python/framework/test_util.py new file mode 100644 index 0000000000..d48dcf8c62 --- /dev/null +++ b/tensorflow/contrib/testing/python/framework/test_util.py @@ -0,0 +1,118 @@ +"""Test utilities.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import glob +import os +import numpy as np +from tensorflow.core.framework import summary_pb2 +from tensorflow.python.platform import logging +from tensorflow.python.training import summary_io + + +def assert_summary(expected_tags, expected_simple_values, summary_proto): + """Asserts summary contains the specified tags and values. + + Args: + expected_tags: All tags in summary. + expected_simple_values: Simply values for some tags. + summary_proto: Summary to validate. + + Raises: + ValueError: if expectations are not met. + """ + actual_tags = set() + for value in summary_proto.value: + actual_tags.add(value.tag) + if value.tag in expected_simple_values: + expected = expected_simple_values[value.tag] + actual = value.simple_value + np.testing.assert_almost_equal( + actual, expected, decimal=2, err_msg=value.tag) + expected_tags = set(expected_tags) + if expected_tags != actual_tags: + raise ValueError('Expected tags %s, got %s.' % (expected_tags, actual_tags)) + + +def to_summary_proto(summary_str): + """Create summary based on latest stats. + + Args: + summary_str: Serialized summary. + Returns: + summary_pb2.Summary. + Raises: + ValueError: if tensor is not a valid summary tensor. + """ + summary = summary_pb2.Summary() + summary.ParseFromString(summary_str) + return summary + + +# TODO(ptucker): Move to a non-test package? +def latest_event_file(base_dir): + """Find latest event file in `base_dir`. + + Args: + base_dir: Base directory in which TF event flies are stored. + Returns: + File path, or `None` if none exists. + """ + file_paths = glob.glob(os.path.join(base_dir, 'events.*')) + return sorted(file_paths)[-1] if file_paths else None + + +def latest_events(base_dir): + """Parse events from latest event file in base_dir. + + Args: + base_dir: Base directory in which TF event flies are stored. + Returns: + Iterable of event protos. + Raises: + ValueError: if no event files exist under base_dir. + """ + file_path = latest_event_file(base_dir) + return summary_io.summary_iterator(file_path) if file_path else [] + + +def latest_summaries(base_dir): + """Parse summary events from latest event file in base_dir. + + Args: + base_dir: Base directory in which TF event flies are stored. + Returns: + List of event protos. + Raises: + ValueError: if no event files exist under base_dir. + """ + return [e for e in latest_events(base_dir) if e.HasField('summary')] + + +def simple_values_from_events(events, tags): + """Parse summaries from events with simple_value. + + Args: + events: List of tensorflow.Event protos. + tags: List of string event tags corresponding to simple_value summaries. + Returns: + dict of tag:value. + Raises: + ValueError: if a summary with a specified tag does not contain simple_value. + """ + step_by_tag = {} + value_by_tag = {} + for e in events: + if e.HasField('summary'): + for v in e.summary.value: + tag = v.tag + if tag in tags: + if not v.HasField('simple_value'): + raise ValueError('Summary for %s is not a simple_value.' % tag) + # The events are mostly sorted in step order, but we explicitly check + # just in case. + if tag not in step_by_tag or e.step > step_by_tag[tag]: + step_by_tag[tag] = e.step + value_by_tag[tag] = v.simple_value + return value_by_tag + |