diff options
author | Dan Ringwalt <ringwalt@google.com> | 2017-05-12 08:22:25 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-12 08:25:52 -0700 |
commit | 733bff53926717bb9583d4833ba062c58f27960f (patch) | |
tree | 8d99c52b2d2079111e60982b4e8bb0a9e4440481 /tensorflow/contrib/util | |
parent | 53dc8ab2a14f913583cc259930f35d777bc7cc81 (diff) |
Add a tf.contrib.util.create_example utility for building Example protos.
PiperOrigin-RevId: 155868794
Diffstat (limited to 'tensorflow/contrib/util')
-rw-r--r-- | tensorflow/contrib/util/BUILD | 19 | ||||
-rw-r--r-- | tensorflow/contrib/util/__init__.py | 3 | ||||
-rw-r--r-- | tensorflow/contrib/util/create_example.py | 61 | ||||
-rw-r--r-- | tensorflow/contrib/util/create_example_test.py | 86 |
4 files changed, 167 insertions, 2 deletions
diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD index 5ad8e3dd35..a6d3ca7242 100644 --- a/tensorflow/contrib/util/BUILD +++ b/tensorflow/contrib/util/BUILD @@ -64,12 +64,29 @@ cc_binary( py_library( name = "util_py", - srcs = glob(["**/*.py"]), + srcs = glob( + ["**/*.py"], + exclude = ["**/*_test.py"], + ), srcs_version = "PY2AND3", deps = [ + "//tensorflow/core:protos_all_py", "//tensorflow/python:framework", "//tensorflow/python:platform", "//tensorflow/python:util", + "//third_party/py/numpy", + ], +) + +py_test( + name = "create_example_test", + srcs = ["create_example_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":util_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:platform_test", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/util/__init__.py b/tensorflow/contrib/util/__init__.py index 08741cf8ca..d976a3f208 100644 --- a/tensorflow/contrib/util/__init__.py +++ b/tensorflow/contrib/util/__init__.py @@ -18,6 +18,7 @@ See @{$python/contrib.util} guide. @@constant_value +@@create_example @@make_tensor_proto @@make_ndarray @@ops_used_by_graph_def @@ -30,11 +31,11 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import +from tensorflow.contrib.util.create_example import create_example from tensorflow.python.framework.meta_graph import ops_used_by_graph_def from tensorflow.python.framework.meta_graph import stripped_op_list_for_graph from tensorflow.python.framework.tensor_util import constant_value from tensorflow.python.framework.tensor_util import make_tensor_proto from tensorflow.python.framework.tensor_util import MakeNdarray as make_ndarray -# pylint: disable=unused_import from tensorflow.python.util.all_util import remove_undocumented remove_undocumented(__name__) diff --git a/tensorflow/contrib/util/create_example.py b/tensorflow/contrib/util/create_example.py new file mode 100644 index 0000000000..8cb6e8809b --- /dev/null +++ b/tensorflow/contrib/util/create_example.py @@ -0,0 +1,61 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for constructing Example protos. + +Takes ndarrays, lists, or tuples for each feature. + +@@create_example +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.core.example import example_pb2 + + +def create_example(**features): + """Constructs a `tf.train.Example` from the given features. + + Args: + **features: Maps feature name to an integer, float, or string ndarray, or + another object convertible to an ndarray (list, tuple, etc). + + Returns: + A `tf.train.Example` with the features. + + Raises: + ValueError: if a feature is not integer, float, or string. + """ + example = example_pb2.Example() + for name in features: + feature = example.features.feature[name] + values = np.asarray(features[name]) + # Encode unicode using UTF-8. + if values.dtype.kind == 'U': + values = np.vectorize(lambda string: string.encode('utf-8'))(values) + + if values.dtype.kind == 'i': + feature.int64_list.value.extend(values.astype(np.int64).ravel()) + elif values.dtype.kind == 'f': + feature.float_list.value.extend(values.astype(np.float32).ravel()) + elif values.dtype.kind == 'S': + feature.bytes_list.value.extend(values.ravel()) + else: + raise ValueError('Feature "%s" has unexpected dtype: %s' % (name, + values.dtype)) + return example diff --git a/tensorflow/contrib/util/create_example_test.py b/tensorflow/contrib/util/create_example_test.py new file mode 100644 index 0000000000..091b6c75d0 --- /dev/null +++ b/tensorflow/contrib/util/create_example_test.py @@ -0,0 +1,86 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the Example creation utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib import util +from tensorflow.core.example import example_pb2 +from tensorflow.python.platform import googletest + + +class CreateExampleTest(googletest.TestCase): + + def testCreateExample_empty(self): + self.assertEqual(util.create_example(), example_pb2.Example()) + + # np.asarray([]) == np.array([], dtype=np.float64), but the dtype should not + # matter here. + actual = util.create_example(foo=[], bar=()) + expected = example_pb2.Example() + expected.features.feature['foo'].float_list.value.extend([]) + expected.features.feature['bar'].float_list.value.extend([]) + self.assertEqual(actual, expected) + + def testCreateExample_scalars(self): + actual = util.create_example(foo=3, bar=4.2, baz='x', qux=b'y') + expected = example_pb2.Example() + expected.features.feature['foo'].int64_list.value.append(3) + # 4.2 cannot be represented exactly in floating point. + expected.features.feature['bar'].float_list.value.append(np.float32(4.2)) + expected.features.feature['baz'].bytes_list.value.append(b'x') + expected.features.feature['qux'].bytes_list.value.append(b'y') + self.assertEqual(actual, expected) + + def testCreateExample_listContainingString(self): + actual = util.create_example(foo=[3, 4.2, 'foo']) + # np.asarray([3, 4.2, 'foo']) == np.array(['3', '4.2', 'foo']) + expected = example_pb2.Example() + expected.features.feature['foo'].bytes_list.value.extend( + [b'3', b'4.2', b'foo']) + self.assertEqual(actual, expected) + + def testCreateExample_lists_tuples_ranges(self): + actual = util.create_example( + foo=[1, 2, 3, 4, 5], bar=(0.5, 0.25, 0.125), baz=range(3)) + expected = example_pb2.Example() + expected.features.feature['foo'].int64_list.value.extend([1, 2, 3, 4, 5]) + expected.features.feature['bar'].float_list.value.extend([0.5, 0.25, 0.125]) + expected.features.feature['baz'].int64_list.value.extend([0, 1, 2]) + self.assertEqual(actual, expected) + + def testCreateExample_ndarrays(self): + a = np.random.random((3, 4, 5)).astype(np.float32) + b = np.random.randint(low=1, high=10, size=(6, 5, 4)) + actual = util.create_example(A=a, B=b) + expected = example_pb2.Example() + expected.features.feature['A'].float_list.value.extend(a.ravel()) + expected.features.feature['B'].int64_list.value.extend(b.ravel()) + self.assertEqual(actual, expected) + + def testCreateExample_unicode(self): + actual = util.create_example(A=[u'\u4242', u'\u5555']) + expected = example_pb2.Example() + expected.features.feature['A'].bytes_list.value.extend( + [u'\u4242'.encode('utf-8'), u'\u5555'.encode('utf-8')]) + self.assertEqual(actual, expected) + + +if __name__ == '__main__': + googletest.main() |