aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/dist_test
diff options
context:
space:
mode:
authorGravatar Justine Tunney <jart@google.com>2016-12-29 22:46:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-29 23:06:59 -0800
commite121667dc609de978a223c56ee906368d2c4ceef (patch)
tree7d4e1f1e1b4fd469487872c0cd34ddace5ac570c /tensorflow/tools/dist_test
parent7815fcba7767aa1eb3196c5861e174f8b3c43bab (diff)
Remove so many more hourglass imports
Change: 143230429
Diffstat (limited to 'tensorflow/tools/dist_test')
-rw-r--r--tensorflow/tools/dist_test/server/BUILD7
-rwxr-xr-xtensorflow/tools/dist_test/server/grpc_tensorflow_server.py25
-rw-r--r--tensorflow/tools/dist_test/server/parse_cluster_spec_test.py20
3 files changed, 25 insertions, 27 deletions
diff --git a/tensorflow/tools/dist_test/server/BUILD b/tensorflow/tools/dist_test/server/BUILD
index 19f52f8208..25efc83716 100644
--- a/tensorflow/tools/dist_test/server/BUILD
+++ b/tensorflow/tools/dist_test/server/BUILD
@@ -15,7 +15,9 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- "//tensorflow:tensorflow_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:training",
],
)
@@ -29,7 +31,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":grpc_tensorflow_server",
- "//tensorflow:tensorflow_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
],
)
diff --git a/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
index 58931e8b2a..5e36eaf748 100755
--- a/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
+++ b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Python-based TensorFlow GRPC server.
Takes input arguments cluster_spec, job_name and task_id, and start a blocking
@@ -30,27 +29,27 @@ Where:
PORT is a port number
"""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import tensorflow as tf
-
+from tensorflow.core.protobuf import tensorflow_server_pb2
+from tensorflow.python.platform import app
+from tensorflow.python.platform import flags
+from tensorflow.python.training import server_lib
-FLAGS = tf.app.flags.FLAGS
+FLAGS = flags.FLAGS
-tf.app.flags.DEFINE_string("cluster_spec", "",
- """Cluster spec: SPEC.
+flags.DEFINE_string("cluster_spec", "", """Cluster spec: SPEC.
SPEC is <JOB>(,<JOB>)*,"
JOB is <NAME>|<HOST:PORT>(;<HOST:PORT>)*,"
NAME is a valid job name ([a-z][0-9a-z]*),"
HOST is a hostname or IP address,"
PORT is a port number."
E.g., local|localhost:2222;localhost:2223, ps|ps0:2222;ps1:2222""")
-tf.app.flags.DEFINE_string("job_name", "", "Job name: e.g., local")
-tf.app.flags.DEFINE_integer("task_id", 0, "Task index, e.g., 0")
-tf.app.flags.DEFINE_boolean("verbose", False, "Verbose mode")
+flags.DEFINE_string("job_name", "", "Job name: e.g., local")
+flags.DEFINE_integer("task_id", 0, "Task index, e.g., 0")
+flags.DEFINE_boolean("verbose", False, "Verbose mode")
def parse_cluster_spec(cluster_spec, cluster):
@@ -99,7 +98,7 @@ def parse_cluster_spec(cluster_spec, cluster):
def main(unused_args):
# Create Protobuf ServerDef
- server_def = tf.train.ServerDef(protocol="grpc")
+ server_def = tensorflow_server_pb2.ServerDef(protocol="grpc")
# Cluster info
parse_cluster_spec(FLAGS.cluster_spec, server_def.cluster)
@@ -115,11 +114,11 @@ def main(unused_args):
server_def.task_index = FLAGS.task_id
# Create GRPC Server instance
- server = tf.train.Server(server_def)
+ server = server_lib.Server(server_def)
# join() is blocking, unlike start()
server.join()
if __name__ == "__main__":
- tf.app.run()
+ app.run()
diff --git a/tensorflow/tools/dist_test/server/parse_cluster_spec_test.py b/tensorflow/tools/dist_test/server/parse_cluster_spec_test.py
index 0497f827fa..28b786ce2c 100644
--- a/tensorflow/tools/dist_test/server/parse_cluster_spec_test.py
+++ b/tensorflow/tools/dist_test/server/parse_cluster_spec_test.py
@@ -12,21 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Tests for cluster-spec string parser in GRPC TensorFlow server."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import tensorflow as tf
-
+from tensorflow.core.protobuf import tensorflow_server_pb2
+from tensorflow.python.platform import test
from tensorflow.tools.dist_test.server import grpc_tensorflow_server
-class ParseClusterSpecStringTest(tf.test.TestCase):
+class ParseClusterSpecStringTest(test.TestCase):
def setUp(self):
- self._cluster = tf.train.ServerDef(protocol="grpc").cluster
+ self._cluster = tensorflow_server_pb2.ServerDef(protocol="grpc").cluster
def test_parse_multi_jobs_sunnyday(self):
cluster_spec = ("worker|worker0:2220;worker1:2221;worker2:2222,"
@@ -50,8 +49,7 @@ class ParseClusterSpecStringTest(tf.test.TestCase):
def test_empty_cluster_spec_string(self):
cluster_spec = ""
- with self.assertRaisesRegexp(ValueError,
- "Empty cluster_spec string"):
+ with self.assertRaisesRegexp(ValueError, "Empty cluster_spec string"):
grpc_tensorflow_server.parse_cluster_spec(cluster_spec, self._cluster)
def test_parse_misused_comma_for_semicolon(self):
@@ -71,18 +69,16 @@ class ParseClusterSpecStringTest(tf.test.TestCase):
def test_parse_empty_job_name(self):
cluster_spec = "worker|worker0:2220,|ps0:3220"
- with self.assertRaisesRegexp(ValueError,
- "Empty job_name in cluster_spec"):
+ with self.assertRaisesRegexp(ValueError, "Empty job_name in cluster_spec"):
grpc_tensorflow_server.parse_cluster_spec(cluster_spec, self._cluster)
print(self._cluster)
def test_parse_empty_task(self):
cluster_spec = "worker|worker0:2220,ps|"
- with self.assertRaisesRegexp(ValueError,
- "Empty task string at position 0"):
+ with self.assertRaisesRegexp(ValueError, "Empty task string at position 0"):
grpc_tensorflow_server.parse_cluster_spec(cluster_spec, self._cluster)
if __name__ == "__main__":
- tf.test.main()
+ test.main()