about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/input_pipeline
diff options
context:
space:
mode:
author: Rohan Jain <rohanj@google.com>, 2016-12-15 18:56:12 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2016-12-15 19:06:46 -0800
commit: 810b3537ff161065c2f3073411e2e16be5cc1b0f (patch)
tree: 1c8e6191a5668ff99a7eaeb4389279d170263d0c /tensorflow/contrib/input_pipeline
parent: 9eca524593029b3adb9f1a83fc4fb09f2ad58957 (diff)
Adds a num_epochs argument to input_pipeline_ops.seek_next to limit the number of times we see each string.
Also changes the input / output shapes of the ObtainNext op from [1] to [] (scalar), which is more intuitive. The *_shared_queue functions in graph_io are then migrated to use seek_next instead.
Change: 142216697
Diffstat (limited to 'tensorflow/contrib/input_pipeline')
-rw-r--r-- tensorflow/contrib/input_pipeline/BUILD | 2
-rw-r--r-- tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc | 13
-rw-r--r-- tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc | 6
-rw-r--r-- tensorflow/contrib/input_pipeline/ops/input_pipeline_ops_test.cc | 6
-rw-r--r-- tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py | 53
-rw-r--r-- tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py (renamed from tensorflow/contrib/input_pipeline/python/kernel_tests/input_pipeline_ops_test.py) | 40
6 files changed, 89 insertions, 31 deletions
diff --git a/tensorflow/contrib/input_pipeline/BUILD b/tensorflow/contrib/input_pipeline/BUILD
index a9cf9160b3..6b88b6eb4c 100644
--- a/tensorflow/contrib/input_pipeline/BUILD
+++ b/tensorflow/contrib/input_pipeline/BUILD
@@ -68,7 +68,7 @@ py_library(
py_test(
name = "input_pipeline_ops_test",
size = "small",
- srcs = ["python/kernel_tests/input_pipeline_ops_test.py"],
+ srcs = ["python/ops/input_pipeline_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
":input_pipeline_py",
diff --git a/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc b/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc
index 3f46dabaaf..ca288c1f73 100644
--- a/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc
+++ b/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc
@@ -34,8 +34,9 @@ class ObtainNextOp : public OpKernel {
// Allocate output.
Tensor* output_tensor = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output("out_element", TensorShape({1}),
- &output_tensor));
+ OP_REQUIRES_OK(
+ ctx,
+ ctx->allocate_output("out_element", TensorShape({}), &output_tensor));
// Obtain mutex for the "counter" tensor.
mutex* mu;
@@ -44,13 +45,11 @@ class ObtainNextOp : public OpKernel {
// Increment "counter" tensor by 1.
Tensor counter_tensor;
OP_REQUIRES_OK(ctx, ctx->mutable_input("counter", &counter_tensor, true));
- auto counter_tensor_flat = counter_tensor.flat<int64>();
- int64& pos = counter_tensor_flat(0);
- pos = (pos + 1) % num_elements;
+ int64* pos = &counter_tensor.scalar<int64>()();
+ *pos = (*pos + 1) % num_elements;
// Assign value to output.
- auto output_tensor_flat = output_tensor->flat<string>();
- output_tensor_flat(0) = list_flat(pos);
+ output_tensor->scalar<string>()() = list_flat(*pos);
}
};
diff --git a/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc b/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc
index 05639ed702..052dbfec33 100644
--- a/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc
+++ b/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc
@@ -29,10 +29,8 @@ REGISTER_OP("ObtainNext")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused_input, input1;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused_input));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &input1));
- DimensionHandle unused_dim;
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input1, 0), 1, &unused_dim));
- c->set_output(0, c->Vector(1));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &input1));
+ c->set_output(0, c->Scalar());
return Status::OK();
})
.Doc(R"doc(
diff --git a/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops_test.cc b/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops_test.cc
index b644c23c7f..fbfed20748 100644
--- a/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops_test.cc
+++ b/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops_test.cc
@@ -29,10 +29,10 @@ TEST(InputPipelineOpsTest, ObtainNext_InvalidNumberOfInputs) {
TEST(InputPipelineOpsTest, ObtainNext) {
ShapeInferenceTestOp op("ObtainNext");
- INFER_OK(op, "[100];[1]", "[1]");
+ INFER_OK(op, "[100];[]", "[]");
- INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,1];[1]");
- INFER_ERROR("Dimension must be 1 but is 2", op, "[1000];[2]");
+ INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,1];[]");
+ INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1000];[1]");
}
} // end namespace tensorflow
diff --git a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py
index d8c95d9c39..edd6d22c5f 100644
--- a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py
+++ b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py
@@ -17,11 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import random
+
from tensorflow.contrib.util import loader
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import resource_loader
@@ -44,26 +45,60 @@ def obtain_next(string_list_tensor, counter):
return _input_pipeline_ops.obtain_next(string_list_tensor, counter)
-def seek_next(string_list):
+def _maybe_randomize_list(string_list, shuffle):
+ if shuffle:
+ random.shuffle(string_list)
+ return string_list
+
+
+def _create_list(string_list, shuffle, seed, num_epochs):
+ if shuffle and seed:
+ random.seed(seed)
+ expanded_list = _maybe_randomize_list(string_list, shuffle)
+ if num_epochs:
+ for _ in range(num_epochs - 1):
+ expanded_list.extend(_maybe_randomize_list(string_list, shuffle))
+ return expanded_list
+
+
+def seek_next(string_list, shuffle=False, seed=None, num_epochs=None):
"""Returns an op that seeks the next element in a list of strings.
Seeking happens in a round robin fashion. This op creates a variable called
counter that is initialized to -1 and is used to keep track of which element
- in the list was returned.
+ in the list was returned. If num_epochs is not None, then we limit the number
+ of times we go around the string_list before OutOfRangeError is thrown. It
+ creates a variable to keep track of this.
Args:
- string_list: A list of strings
+ string_list: A list of strings.
+ shuffle: If true, we shuffle the string_list differently for each epoch.
+ seed: Seed used for shuffling.
+ num_epochs: Returns OutOfRangeError once string_list has been repeated
+ num_epoch times. If unspecified then keeps on looping.
Returns:
An op that produces the next element in the provided list.
"""
+ expanded_list = _create_list(string_list, shuffle, seed, num_epochs)
+
with variable_scope.variable_scope("obtain_next"):
counter = variable_scope.get_variable(
name="obtain_next_counter",
- initializer=constant_op.constant([-1], dtype=dtypes.int64),
+ initializer=constant_op.constant(
+ -1, dtype=dtypes.int64),
dtype=dtypes.int64)
with ops.device(counter.device):
- string_tensor = constant_op.constant(string_list,
- name="obtain_next_string_list")
- return obtain_next(string_tensor, counter)
-
+ string_tensor = constant_op.constant(
+ expanded_list, name="obtain_next_expanded_list")
+ if num_epochs:
+ filename_counter = variable_scope.get_variable(
+ name="obtain_next_filename_counter",
+ initializer=constant_op.constant(
+ 0, dtype=dtypes.int64),
+ dtype=dtypes.int64)
+ c = filename_counter.count_up_to(len(expanded_list))
+ with ops.control_dependencies([c]):
+ return obtain_next(string_tensor, counter)
+ else:
+ return obtain_next(string_tensor, counter)
diff --git a/tensorflow/contrib/input_pipeline/python/kernel_tests/input_pipeline_ops_test.py b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
index b8f1d9c0e4..00467b8cc9 100644
--- a/tensorflow/contrib/input_pipeline/python/kernel_tests/input_pipeline_ops_test.py
+++ b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
@@ -20,37 +20,63 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.input_pipeline.python.ops import input_pipeline_ops
+from tensorflow.python.framework import errors
from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variables as var_ops
class InputPipelineOpsTest(tf.test.TestCase):
def testObtainNext(self):
with self.test_session():
- var = state_ops.variable_op([1], tf.int64)
- tf.assign(var, [-1]).op.run()
+ var = state_ops.variable_op([], tf.int64)
+ tf.assign(var, -1).op.run()
c = tf.constant(["a", "b"])
sample1 = input_pipeline_ops.obtain_next(c, var)
self.assertEqual(b"a", sample1.eval())
- self.assertEqual([0], var.eval())
+ self.assertEqual(0, var.eval())
sample2 = input_pipeline_ops.obtain_next(c, var)
self.assertEqual(b"b", sample2.eval())
- self.assertEqual([1], var.eval())
+ self.assertEqual(1, var.eval())
sample3 = input_pipeline_ops.obtain_next(c, var)
self.assertEqual(b"a", sample3.eval())
- self.assertEqual([0], var.eval())
+ self.assertEqual(0, var.eval())
def testSeekNext(self):
string_list = ["a", "b", "c"]
with self.test_session() as session:
elem = input_pipeline_ops.seek_next(string_list)
- session.run(tf.initialize_all_variables())
+ session.run([tf.global_variables_initializer()])
self.assertEqual(b"a", session.run(elem))
self.assertEqual(b"b", session.run(elem))
self.assertEqual(b"c", session.run(elem))
+ # Make sure we loop.
self.assertEqual(b"a", session.run(elem))
+ # Helper method that runs the op len(expected_list) number of times, asserts
+ # that the results are elements of the expected_list and then throws an
+ # OutOfRangeError.
+ def _assert_output(self, expected_list, session, op):
+ for element in expected_list:
+ self.assertEqual(element, session.run(op))
+ with self.assertRaises(errors.OutOfRangeError):
+ session.run(op)
+
+ def testSeekNextLimitEpochs(self):
+ string_list = ["a", "b", "c"]
+ with self.test_session() as session:
+ elem = input_pipeline_ops.seek_next(string_list, num_epochs=1)
+ session.run(
+ [tf.local_variables_initializer(), tf.global_variables_initializer()])
+ self._assert_output([b"a", b"b", b"c"], session, elem)
+
+ def testSeekNextLimitEpochsTwo(self):
+ string_list = ["a", "b", "c"]
+ with self.test_session() as session:
+ elem = input_pipeline_ops.seek_next(string_list, num_epochs=2)
+ session.run(
+ [tf.local_variables_initializer(), tf.global_variables_initializer()])
+ # Expect to see [a, b, c] two times.
+ self._assert_output([b"a", b"b", b"c"] * 2, session, elem)
if __name__ == "__main__":
tf.test.main()