diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-08-02 12:17:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-02 12:22:03 -0700 |
commit | b12f1df7c53500316e1d26674cf2cacf89df9c70 (patch) | |
tree | b1157b9c83c49eea3a91807189ea3de87d1faf72 /tensorflow/contrib/input_pipeline | |
parent | 6f786ddc04c7931fadf81c35d8f76b4c0f71b815 (diff) |
Fix a bug in seek_next due to which seeks were performed 2^(num_epochs - 1)
times instead of num_epochs times.
expanded_list and string_list were copies by reference which mean the list was
being doubled with each call to extend().
Also don't add the variables created to the TRAINABLE variables collection.
PiperOrigin-RevId: 164020330
Diffstat (limited to 'tensorflow/contrib/input_pipeline')
-rw-r--r-- | tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py | 11 | ||||
-rw-r--r-- | tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py | 8 |
2 files changed, 11 insertions, 8 deletions
diff --git a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py index be22b5c579..0092ae74b2 100644 --- a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py +++ b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops.py @@ -55,7 +55,7 @@ def _maybe_randomize_list(string_list, shuffle): def _create_list(string_list, shuffle, seed, num_epochs): if shuffle and seed: random.seed(seed) - expanded_list = _maybe_randomize_list(string_list, shuffle) + expanded_list = _maybe_randomize_list(string_list, shuffle)[:] if num_epochs: for _ in range(num_epochs - 1): expanded_list.extend(_maybe_randomize_list(string_list, shuffle)) @@ -89,18 +89,21 @@ def seek_next(string_list, shuffle=False, seed=None, num_epochs=None): name="obtain_next_counter", initializer=constant_op.constant( -1, dtype=dtypes.int64), - dtype=dtypes.int64) + dtype=dtypes.int64, + trainable=False) with ops.colocate_with(counter): string_tensor = variable_scope.get_variable( name="obtain_next_expanded_list", initializer=constant_op.constant(expanded_list), - dtype=dtypes.string) + dtype=dtypes.string, + trainable=False) if num_epochs: filename_counter = variable_scope.get_variable( name="obtain_next_filename_counter", initializer=constant_op.constant( 0, dtype=dtypes.int64), - dtype=dtypes.int64) + dtype=dtypes.int64, + trainable=False) c = filename_counter.count_up_to(len(expanded_list)) with ops.control_dependencies([c]): return obtain_next(string_tensor, counter) diff --git a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py index d6c0bd62de..9ed017592a 100644 --- a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py +++ b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py @@ -73,16 +73,16 @@ class InputPipelineOpsTest(test.TestCase): ]) self._assert_output([b"a", b"b", b"c"], session, elem) - def testSeekNextLimitEpochsTwo(self): + def testSeekNextLimitEpochsThree(self): string_list = ["a", "b", "c"] with self.test_session() as session: - elem = input_pipeline_ops.seek_next(string_list, num_epochs=2) + elem = input_pipeline_ops.seek_next(string_list, num_epochs=3) session.run([ variables.local_variables_initializer(), variables.global_variables_initializer() ]) - # Expect to see [a, b, c] two times. - self._assert_output([b"a", b"b", b"c"] * 2, session, elem) + # Expect to see [a, b, c] three times. + self._assert_output([b"a", b"b", b"c"] * 3, session, elem) if __name__ == "__main__": |