aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/util
diff options
context:
space:
mode:
authorGravatar Dan Ringwalt <ringwalt@google.com>2017-05-12 08:22:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-12 08:25:52 -0700
commit733bff53926717bb9583d4833ba062c58f27960f (patch)
tree8d99c52b2d2079111e60982b4e8bb0a9e4440481 /tensorflow/contrib/util
parent53dc8ab2a14f913583cc259930f35d777bc7cc81 (diff)
Add a tf.contrib.util.create_example utility for building Example protos.
PiperOrigin-RevId: 155868794
Diffstat (limited to 'tensorflow/contrib/util')
-rw-r--r--tensorflow/contrib/util/BUILD19
-rw-r--r--tensorflow/contrib/util/__init__.py3
-rw-r--r--tensorflow/contrib/util/create_example.py61
-rw-r--r--tensorflow/contrib/util/create_example_test.py86
4 files changed, 167 insertions, 2 deletions
diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD
index 5ad8e3dd35..a6d3ca7242 100644
--- a/tensorflow/contrib/util/BUILD
+++ b/tensorflow/contrib/util/BUILD
@@ -64,12 +64,29 @@ cc_binary(
py_library(
name = "util_py",
- srcs = glob(["**/*.py"]),
+ srcs = glob(
+ ["**/*.py"],
+ exclude = ["**/*_test.py"],
+ ),
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python:framework",
"//tensorflow/python:platform",
"//tensorflow/python:util",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "create_example_test",
+ srcs = ["create_example_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":util_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:platform_test",
+ "//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/util/__init__.py b/tensorflow/contrib/util/__init__.py
index 08741cf8ca..d976a3f208 100644
--- a/tensorflow/contrib/util/__init__.py
+++ b/tensorflow/contrib/util/__init__.py
@@ -18,6 +18,7 @@
See @{$python/contrib.util} guide.
@@constant_value
+@@create_example
@@make_tensor_proto
@@make_ndarray
@@ops_used_by_graph_def
@@ -30,11 +31,11 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
+from tensorflow.contrib.util.create_example import create_example
from tensorflow.python.framework.meta_graph import ops_used_by_graph_def
from tensorflow.python.framework.meta_graph import stripped_op_list_for_graph
from tensorflow.python.framework.tensor_util import constant_value
from tensorflow.python.framework.tensor_util import make_tensor_proto
from tensorflow.python.framework.tensor_util import MakeNdarray as make_ndarray
-# pylint: disable=unused_import
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__)
diff --git a/tensorflow/contrib/util/create_example.py b/tensorflow/contrib/util/create_example.py
new file mode 100644
index 0000000000..8cb6e8809b
--- /dev/null
+++ b/tensorflow/contrib/util/create_example.py
@@ -0,0 +1,61 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for constructing Example protos.
+
+Takes ndarrays, lists, or tuples for each feature.
+
+@@create_example
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.core.example import example_pb2
+
+
+def create_example(**features):
+ """Constructs a `tf.train.Example` from the given features.
+
+ Args:
+ **features: Maps feature name to an integer, float, or string ndarray, or
+ another object convertible to an ndarray (list, tuple, etc).
+
+ Returns:
+ A `tf.train.Example` with the features.
+
+ Raises:
+ ValueError: if a feature is not integer, float, or string.
+ """
+ example = example_pb2.Example()
+ for name in features:
+ feature = example.features.feature[name]
+ values = np.asarray(features[name])
+ # Encode unicode using UTF-8.
+ if values.dtype.kind == 'U':
+ values = np.vectorize(lambda string: string.encode('utf-8'))(values)
+
+ if values.dtype.kind == 'i':
+ feature.int64_list.value.extend(values.astype(np.int64).ravel())
+ elif values.dtype.kind == 'f':
+ feature.float_list.value.extend(values.astype(np.float32).ravel())
+ elif values.dtype.kind == 'S':
+ feature.bytes_list.value.extend(values.ravel())
+ else:
+ raise ValueError('Feature "%s" has unexpected dtype: %s' % (name,
+ values.dtype))
+ return example
diff --git a/tensorflow/contrib/util/create_example_test.py b/tensorflow/contrib/util/create_example_test.py
new file mode 100644
index 0000000000..091b6c75d0
--- /dev/null
+++ b/tensorflow/contrib/util/create_example_test.py
@@ -0,0 +1,86 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the Example creation utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib import util
+from tensorflow.core.example import example_pb2
+from tensorflow.python.platform import googletest
+
+
+class CreateExampleTest(googletest.TestCase):
+
+ def testCreateExample_empty(self):
+ self.assertEqual(util.create_example(), example_pb2.Example())
+
+ # np.asarray([]) == np.array([], dtype=np.float64), but the dtype should not
+ # matter here.
+ actual = util.create_example(foo=[], bar=())
+ expected = example_pb2.Example()
+ expected.features.feature['foo'].float_list.value.extend([])
+ expected.features.feature['bar'].float_list.value.extend([])
+ self.assertEqual(actual, expected)
+
+ def testCreateExample_scalars(self):
+ actual = util.create_example(foo=3, bar=4.2, baz='x', qux=b'y')
+ expected = example_pb2.Example()
+ expected.features.feature['foo'].int64_list.value.append(3)
+ # 4.2 cannot be represented exactly in floating point.
+ expected.features.feature['bar'].float_list.value.append(np.float32(4.2))
+ expected.features.feature['baz'].bytes_list.value.append(b'x')
+ expected.features.feature['qux'].bytes_list.value.append(b'y')
+ self.assertEqual(actual, expected)
+
+ def testCreateExample_listContainingString(self):
+ actual = util.create_example(foo=[3, 4.2, 'foo'])
+ # np.asarray([3, 4.2, 'foo']) == np.array(['3', '4.2', 'foo'])
+ expected = example_pb2.Example()
+ expected.features.feature['foo'].bytes_list.value.extend(
+ [b'3', b'4.2', b'foo'])
+ self.assertEqual(actual, expected)
+
+ def testCreateExample_lists_tuples_ranges(self):
+ actual = util.create_example(
+ foo=[1, 2, 3, 4, 5], bar=(0.5, 0.25, 0.125), baz=range(3))
+ expected = example_pb2.Example()
+ expected.features.feature['foo'].int64_list.value.extend([1, 2, 3, 4, 5])
+ expected.features.feature['bar'].float_list.value.extend([0.5, 0.25, 0.125])
+ expected.features.feature['baz'].int64_list.value.extend([0, 1, 2])
+ self.assertEqual(actual, expected)
+
+ def testCreateExample_ndarrays(self):
+ a = np.random.random((3, 4, 5)).astype(np.float32)
+ b = np.random.randint(low=1, high=10, size=(6, 5, 4))
+ actual = util.create_example(A=a, B=b)
+ expected = example_pb2.Example()
+ expected.features.feature['A'].float_list.value.extend(a.ravel())
+ expected.features.feature['B'].int64_list.value.extend(b.ravel())
+ self.assertEqual(actual, expected)
+
+ def testCreateExample_unicode(self):
+ actual = util.create_example(A=[u'\u4242', u'\u5555'])
+ expected = example_pb2.Example()
+ expected.features.feature['A'].bytes_list.value.extend(
+ [u'\u4242'.encode('utf-8'), u'\u5555'.encode('utf-8')])
+ self.assertEqual(actual, expected)
+
+
+if __name__ == '__main__':
+ googletest.main()