diff options
author | Rohan Jain <rohanj@google.com> | 2016-12-15 18:56:12 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-15 19:06:46 -0800 |
commit | 810b3537ff161065c2f3073411e2e16be5cc1b0f (patch) | |
tree | 1c8e6191a5668ff99a7eaeb4389279d170263d0c /tensorflow/contrib/input_pipeline | |
parent | 9eca524593029b3adb9f1a83fc4fb09f2ad58957 (diff) |
Adds num_epochs argument to input_pipeline_ops.seek_next to limit number of times we see each string.
Also changes the shape of the input / output for the ObtainNext Op from [1] shapes to []'s instead which are more intuitive.
Then we migrate *_shared_queue functions in graph_io to use seek_next instead.
Change: 142216697
Diffstat (limited to 'tensorflow/contrib/input_pipeline')
-rw-r--r-- | tensorflow/contrib/input_pipeline/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc | 13 | ||||
-rw-r--r-- | tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc | 6 | ||||
-rw-r--r-- | tensorflow/contrib/input_pipeline/ops/input_pipeline_ops_test.cc | 6 | ||||
-rw-r--r-- | tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py | 53 | ||||
-rw-r--r-- | tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py (renamed from tensorflow/contrib/input_pipeline/python/kernel_tests/input_pipeline_ops_test.py) | 40 |
6 files changed, 89 insertions, 31 deletions
diff --git a/tensorflow/contrib/input_pipeline/BUILD b/tensorflow/contrib/input_pipeline/BUILD index a9cf9160b3..6b88b6eb4c 100644 --- a/tensorflow/contrib/input_pipeline/BUILD +++ b/tensorflow/contrib/input_pipeline/BUILD @@ -68,7 +68,7 @@ py_library( py_test( name = "input_pipeline_ops_test", size = "small", - srcs = ["python/kernel_tests/input_pipeline_ops_test.py"], + srcs = ["python/ops/input_pipeline_ops_test.py"], srcs_version = "PY2AND3", deps = [ ":input_pipeline_py", diff --git a/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc b/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc index 3f46dabaaf..ca288c1f73 100644 --- a/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc +++ b/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc @@ -34,8 +34,9 @@ class ObtainNextOp : public OpKernel { // Allocate output. Tensor* output_tensor = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output("out_element", TensorShape({1}), - &output_tensor)); + OP_REQUIRES_OK( + ctx, + ctx->allocate_output("out_element", TensorShape({}), &output_tensor)); // Obtain mutex for the "counter" tensor. mutex* mu; @@ -44,13 +45,11 @@ class ObtainNextOp : public OpKernel { // Increment "counter" tensor by 1. Tensor counter_tensor; OP_REQUIRES_OK(ctx, ctx->mutable_input("counter", &counter_tensor, true)); - auto counter_tensor_flat = counter_tensor.flat<int64>(); - int64& pos = counter_tensor_flat(0); - pos = (pos + 1) % num_elements; + int64* pos = &counter_tensor.scalar<int64>()(); + *pos = (*pos + 1) % num_elements; // Assign value to output. - auto output_tensor_flat = output_tensor->flat<string>(); - output_tensor_flat(0) = list_flat(pos); + output_tensor->scalar<string>()() = list_flat(*pos); } }; diff --git a/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc b/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc index 05639ed702..052dbfec33 100644 --- a/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc +++ b/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc @@ -29,10 +29,8 @@ REGISTER_OP("ObtainNext") .SetShapeFn([](InferenceContext* c) { ShapeHandle unused_input, input1; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused_input)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &input1)); - DimensionHandle unused_dim; - TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input1, 0), 1, &unused_dim)); - c->set_output(0, c->Vector(1)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &input1)); + c->set_output(0, c->Scalar()); return Status::OK(); }) .Doc(R"doc( diff --git a/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops_test.cc b/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops_test.cc index b644c23c7f..fbfed20748 100644 --- a/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops_test.cc +++ b/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops_test.cc @@ -29,10 +29,10 @@ TEST(InputPipelineOpsTest, ObtainNext_InvalidNumberOfInputs) { TEST(InputPipelineOpsTest, ObtainNext) { ShapeInferenceTestOp op("ObtainNext"); - INFER_OK(op, "[100];[1]", "[1]"); + INFER_OK(op, "[100];[]", "[]"); - INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,1];[1]"); - INFER_ERROR("Dimension must be 1 but is 2", op, "[1000];[2]"); + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "[1,1];[]"); + INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[1000];[1]"); } } // end namespace tensorflow diff --git a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py index d8c95d9c39..edd6d22c5f 100644 --- a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py +++ b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py @@ -17,11 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import random + from tensorflow.contrib.util import loader from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import resource_loader @@ -44,26 +45,60 @@ def obtain_next(string_list_tensor, counter): return _input_pipeline_ops.obtain_next(string_list_tensor, counter) -def seek_next(string_list): +def _maybe_randomize_list(string_list, shuffle): + if shuffle: + random.shuffle(string_list) + return string_list + + +def _create_list(string_list, shuffle, seed, num_epochs): + if shuffle and seed: + random.seed(seed) + expanded_list = _maybe_randomize_list(string_list, shuffle) + if num_epochs: + for _ in range(num_epochs - 1): + expanded_list.extend(_maybe_randomize_list(string_list, shuffle)) + return expanded_list + + +def seek_next(string_list, shuffle=False, seed=None, num_epochs=None): """Returns an op that seeks the next element in a list of strings. Seeking happens in a round robin fashion. This op creates a variable called counter that is initialized to -1 and is used to keep track of which element - in the list was returned. + in the list was returned. If num_epochs is not None, then we limit the number + of times we go around the string_list before OutOfRangeError is thrown. It + creates a variable to keep track of this. Args: - string_list: A list of strings + string_list: A list of strings. + shuffle: If true, we shuffle the string_list differently for each epoch. + seed: Seed used for shuffling. + num_epochs: Returns OutOfRangeError once string_list has been repeated + num_epoch times. If unspecified then keeps on looping. Returns: An op that produces the next element in the provided list. """ + expanded_list = _create_list(string_list, shuffle, seed, num_epochs) + with variable_scope.variable_scope("obtain_next"): counter = variable_scope.get_variable( name="obtain_next_counter", - initializer=constant_op.constant([-1], dtype=dtypes.int64), + initializer=constant_op.constant( + -1, dtype=dtypes.int64), dtype=dtypes.int64) with ops.device(counter.device): - string_tensor = constant_op.constant(string_list, - name="obtain_next_string_list") - return obtain_next(string_tensor, counter) - + string_tensor = constant_op.constant( + expanded_list, name="obtain_next_expanded_list") + if num_epochs: + filename_counter = variable_scope.get_variable( + name="obtain_next_filename_counter", + initializer=constant_op.constant( + 0, dtype=dtypes.int64), + dtype=dtypes.int64) + c = filename_counter.count_up_to(len(expanded_list)) + with ops.control_dependencies([c]): + return obtain_next(string_tensor, counter) + else: + return obtain_next(string_tensor, counter) diff --git a/tensorflow/contrib/input_pipeline/python/kernel_tests/input_pipeline_ops_test.py b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py index b8f1d9c0e4..00467b8cc9 100644 --- a/tensorflow/contrib/input_pipeline/python/kernel_tests/input_pipeline_ops_test.py +++ b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py @@ -20,37 +20,63 @@ from __future__ import print_function import tensorflow as tf from tensorflow.contrib.input_pipeline.python.ops import input_pipeline_ops +from tensorflow.python.framework import errors from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variables as var_ops class InputPipelineOpsTest(tf.test.TestCase): def testObtainNext(self): with self.test_session(): - var = state_ops.variable_op([1], tf.int64) - tf.assign(var, [-1]).op.run() + var = state_ops.variable_op([], tf.int64) + tf.assign(var, -1).op.run() c = tf.constant(["a", "b"]) sample1 = input_pipeline_ops.obtain_next(c, var) self.assertEqual(b"a", sample1.eval()) - self.assertEqual([0], var.eval()) + self.assertEqual(0, var.eval()) sample2 = input_pipeline_ops.obtain_next(c, var) self.assertEqual(b"b", sample2.eval()) - self.assertEqual([1], var.eval()) + self.assertEqual(1, var.eval()) sample3 = input_pipeline_ops.obtain_next(c, var) self.assertEqual(b"a", sample3.eval()) - self.assertEqual([0], var.eval()) + self.assertEqual(0, var.eval()) def testSeekNext(self): string_list = ["a", "b", "c"] with self.test_session() as session: elem = input_pipeline_ops.seek_next(string_list) - session.run(tf.initialize_all_variables()) + session.run([tf.global_variables_initializer()]) self.assertEqual(b"a", session.run(elem)) self.assertEqual(b"b", session.run(elem)) self.assertEqual(b"c", session.run(elem)) + # Make sure we loop. self.assertEqual(b"a", session.run(elem)) + # Helper method that runs the op len(expected_list) number of times, asserts + # that the results are elements of the expected_list and then throws an + # OutOfRangeError. + def _assert_output(self, expected_list, session, op): + for element in expected_list: + self.assertEqual(element, session.run(op)) + with self.assertRaises(errors.OutOfRangeError): + session.run(op) + + def testSeekNextLimitEpochs(self): + string_list = ["a", "b", "c"] + with self.test_session() as session: + elem = input_pipeline_ops.seek_next(string_list, num_epochs=1) + session.run( + [tf.local_variables_initializer(), tf.global_variables_initializer()]) + self._assert_output([b"a", b"b", b"c"], session, elem) + + def testSeekNextLimitEpochsTwo(self): + string_list = ["a", "b", "c"] + with self.test_session() as session: + elem = input_pipeline_ops.seek_next(string_list, num_epochs=2) + session.run( + [tf.local_variables_initializer(), tf.global_variables_initializer()]) + # Expect to see [a, b, c] two times. + self._assert_output([b"a", b"b", b"c"] * 2, session, elem) if __name__ == "__main__": tf.test.main() |