aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/input_pipeline
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-02 12:17:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-02 12:22:03 -0700
commitb12f1df7c53500316e1d26674cf2cacf89df9c70 (patch)
treeb1157b9c83c49eea3a91807189ea3de87d1faf72 /tensorflow/contrib/input_pipeline
parent6f786ddc04c7931fadf81c35d8f76b4c0f71b815 (diff)
Fix a bug in seek_next due to which seeks were performed 2^(num_epochs - 1)
times instead of num_epochs times. expanded_list and string_list were copies by reference which mean the list was being doubled with each call to extend(). Also don't add the variables created to the TRAINABLE variables collection. PiperOrigin-RevId: 164020330
Diffstat (limited to 'tensorflow/contrib/input_pipeline')
-rw-r--r--tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py11
-rw-r--r--tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py8
2 files changed, 11 insertions, 8 deletions
diff --git a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py
index be22b5c579..0092ae74b2 100644
--- a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py
+++ b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py
@@ -55,7 +55,7 @@ def _maybe_randomize_list(string_list, shuffle):
def _create_list(string_list, shuffle, seed, num_epochs):
if shuffle and seed:
random.seed(seed)
- expanded_list = _maybe_randomize_list(string_list, shuffle)
+ expanded_list = _maybe_randomize_list(string_list, shuffle)[:]
if num_epochs:
for _ in range(num_epochs - 1):
expanded_list.extend(_maybe_randomize_list(string_list, shuffle))
@@ -89,18 +89,21 @@ def seek_next(string_list, shuffle=False, seed=None, num_epochs=None):
name="obtain_next_counter",
initializer=constant_op.constant(
-1, dtype=dtypes.int64),
- dtype=dtypes.int64)
+ dtype=dtypes.int64,
+ trainable=False)
with ops.colocate_with(counter):
string_tensor = variable_scope.get_variable(
name="obtain_next_expanded_list",
initializer=constant_op.constant(expanded_list),
- dtype=dtypes.string)
+ dtype=dtypes.string,
+ trainable=False)
if num_epochs:
filename_counter = variable_scope.get_variable(
name="obtain_next_filename_counter",
initializer=constant_op.constant(
0, dtype=dtypes.int64),
- dtype=dtypes.int64)
+ dtype=dtypes.int64,
+ trainable=False)
c = filename_counter.count_up_to(len(expanded_list))
with ops.control_dependencies([c]):
return obtain_next(string_tensor, counter)
diff --git a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
index d6c0bd62de..9ed017592a 100644
--- a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
+++ b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
@@ -73,16 +73,16 @@ class InputPipelineOpsTest(test.TestCase):
])
self._assert_output([b"a", b"b", b"c"], session, elem)
- def testSeekNextLimitEpochsTwo(self):
+ def testSeekNextLimitEpochsThree(self):
string_list = ["a", "b", "c"]
with self.test_session() as session:
- elem = input_pipeline_ops.seek_next(string_list, num_epochs=2)
+ elem = input_pipeline_ops.seek_next(string_list, num_epochs=3)
session.run([
variables.local_variables_initializer(),
variables.global_variables_initializer()
])
- # Expect to see [a, b, c] two times.
- self._assert_output([b"a", b"b", b"c"] * 2, session, elem)
+ # Expect to see [a, b, c] three times.
+ self._assert_output([b"a", b"b", b"c"] * 3, session, elem)
if __name__ == "__main__":