aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/input_pipeline
diff options
context:
space:
mode:
authorGravatar Justine Tunney <jart@google.com>2016-12-29 22:46:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-29 23:06:59 -0800
commite121667dc609de978a223c56ee906368d2c4ceef (patch)
tree7d4e1f1e1b4fd469487872c0cd34ddace5ac570c /tensorflow/contrib/input_pipeline
parent7815fcba7767aa1eb3196c5861e174f8b3c43bab (diff)
Remove so many more hourglass imports
Change: 143230429
Diffstat (limited to 'tensorflow/contrib/input_pipeline')
-rw-r--r--tensorflow/contrib/input_pipeline/BUILD2
-rw-r--r--tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py31
2 files changed, 21 insertions, 12 deletions
diff --git a/tensorflow/contrib/input_pipeline/BUILD b/tensorflow/contrib/input_pipeline/BUILD
index 0c2e065e60..8eb8201f08 100644
--- a/tensorflow/contrib/input_pipeline/BUILD
+++ b/tensorflow/contrib/input_pipeline/BUILD
@@ -58,12 +58,14 @@ py_library(
deps = [
":input_pipeline_ops",
"//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:state_ops",
"//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
],
)
diff --git a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
index 00467b8cc9..d6c0bd62de 100644
--- a/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
+++ b/tensorflow/contrib/input_pipeline/python/ops/input_pipeline_ops_test.py
@@ -17,20 +17,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import tensorflow as tf
-
from tensorflow.contrib.input_pipeline.python.ops import input_pipeline_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
-class InputPipelineOpsTest(tf.test.TestCase):
+class InputPipelineOpsTest(test.TestCase):
def testObtainNext(self):
with self.test_session():
- var = state_ops.variable_op([], tf.int64)
- tf.assign(var, -1).op.run()
- c = tf.constant(["a", "b"])
+ var = state_ops.variable_op([], dtypes.int64)
+ state_ops.assign(var, -1).op.run()
+ c = constant_op.constant(["a", "b"])
sample1 = input_pipeline_ops.obtain_next(c, var)
self.assertEqual(b"a", sample1.eval())
self.assertEqual(0, var.eval())
@@ -45,7 +47,7 @@ class InputPipelineOpsTest(tf.test.TestCase):
string_list = ["a", "b", "c"]
with self.test_session() as session:
elem = input_pipeline_ops.seek_next(string_list)
- session.run([tf.global_variables_initializer()])
+ session.run([variables.global_variables_initializer()])
self.assertEqual(b"a", session.run(elem))
self.assertEqual(b"b", session.run(elem))
self.assertEqual(b"c", session.run(elem))
@@ -65,18 +67,23 @@ class InputPipelineOpsTest(tf.test.TestCase):
string_list = ["a", "b", "c"]
with self.test_session() as session:
elem = input_pipeline_ops.seek_next(string_list, num_epochs=1)
- session.run(
- [tf.local_variables_initializer(), tf.global_variables_initializer()])
+ session.run([
+ variables.local_variables_initializer(),
+ variables.global_variables_initializer()
+ ])
self._assert_output([b"a", b"b", b"c"], session, elem)
def testSeekNextLimitEpochsTwo(self):
string_list = ["a", "b", "c"]
with self.test_session() as session:
elem = input_pipeline_ops.seek_next(string_list, num_epochs=2)
- session.run(
- [tf.local_variables_initializer(), tf.global_variables_initializer()])
+ session.run([
+ variables.local_variables_initializer(),
+ variables.global_variables_initializer()
+ ])
# Expect to see [a, b, c] two times.
self._assert_output([b"a", b"b", b"c"] * 2, session, elem)
+
if __name__ == "__main__":
- tf.test.main()
+ test.main()