aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/dist_test
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-02-21 17:25:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-21 17:46:52 -0800
commit123c2bb0af532d5fdaa05358158da33497d4bfe6 (patch)
tree1294a19ced2c25e4f63f5c0baa1dc7a3dece7562 /tensorflow/tools/dist_test
parent7ed571e69de11161f3fad0fd91ccbd1ab47f7ad7 (diff)
Utils for starting, waiting for and stopping distributed benchmarks using
kubernetes. Change: 148166332
Diffstat (limited to 'tensorflow/tools/dist_test')
-rw-r--r--tensorflow/tools/dist_test/scripts/BUILD16
-rw-r--r--tensorflow/tools/dist_test/scripts/kubectl_util.py196
-rw-r--r--tensorflow/tools/dist_test/scripts/kubectl_util_test.py72
3 files changed, 284 insertions, 0 deletions
diff --git a/tensorflow/tools/dist_test/scripts/BUILD b/tensorflow/tools/dist_test/scripts/BUILD
index abcd521a26..9041a19af7 100644
--- a/tensorflow/tools/dist_test/scripts/BUILD
+++ b/tensorflow/tools/dist_test/scripts/BUILD
@@ -19,3 +19,19 @@ py_test(
"//tensorflow/python:client_testlib",
],
)
+
+py_library(
+ name = "kubectl_util_lib",
+ srcs = ["kubectl_util.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "kubectl_util_test",
+ srcs = ["kubectl_util_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":kubectl_util_lib",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/tools/dist_test/scripts/kubectl_util.py b/tensorflow/tools/dist_test/scripts/kubectl_util.py
new file mode 100644
index 0000000000..ae0c5601e9
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/kubectl_util.py
@@ -0,0 +1,196 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utils for running, waiting and stopping benchmark jobs on kubernetes.
+
+Functions in this file assume kubernetes jobs have 'name-prefix' and 'job'
+selectors set.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+import subprocess
+import time
+
+
+_KUBECTL = 'kubectl'
+WAIT_PERIOD_SECONDS = 20
+
+
+class TimeoutError(Exception):
+ pass
+
+
+def _WaitUntil(timeout, predicate, *args):
+ start_time = time.time()
+ while time.time() - start_time < timeout:
+ time.sleep(WAIT_PERIOD_SECONDS)
+ if predicate(*args):
+ return True
+ return False
+
+
+def _GetPodNames(pod_name_prefix, job_name=None):
+ """Get pod names based on the pod_name_prefix and job_name.
+
+ Args:
+ pod_name_prefix: value of 'name-prefix' selector.
+ job_name: value of 'job' selector. If None, pod names will be
+ selected only based on 'name-prefix' selector.
+
+ Returns:
+ List of pod names.
+ """
+ pod_list_command = [
+ _KUBECTL, 'get', 'pods', '-o', 'name',
+ '-l', _GetJobSelector(pod_name_prefix, job_name)]
+ logging.info('Command to get pod names: %s', ' '.join(pod_list_command))
+ output = subprocess.check_output(pod_list_command, universal_newlines=True)
+ pod_names = [name for name in output.strip().split('\n') if name]
+ logging.info('Pod names: "%s"', ','.join(pod_names))
+ return pod_names
+
+
+def CreatePods(pod_name, yaml_file):
+ """Creates pods based on the given kubernetes config.
+
+ Args:
+ pod_name: 'name-prefix' selector for the pods.
+ yaml_file: kubernetes yaml config.
+
+ Raises:
+ TimeoutError: if jobs didn't come up for a long time.
+ """
+ command = [_KUBECTL, 'create', '--filename=%s' % yaml_file]
+ logging.info('Creating pods: %s', subprocess.list2cmdline(command))
+ subprocess.check_call(command)
+
+ if not _WaitUntil(100, _GetPodNames, pod_name):
+ raise TimeoutError(
+ 'Timed out waiting for %s pod to come up.' % pod_name)
+
+
+def DeletePods(pod_name, yaml_file):
+ """Deletes pods based on the given kubernetes config.
+
+ Args:
+ pod_name: 'name-prefix' selector for the pods.
+ yaml_file: kubernetes yaml config.
+
+ Raises:
+ TimeoutError: if jobs didn't terminate for a long time.
+ """
+ command = [_KUBECTL, 'delete', '--filename=%s' % yaml_file]
+ logging.info('Deleting pods: %s', ' '.join(command))
+ subprocess.call(command)
+
+ def CheckPodsAreTerminated():
+ return not _GetPodNames(pod_name)
+ if not _WaitUntil(100, CheckPodsAreTerminated):
+ raise TimeoutError(
+ 'Timed out waiting for %s pod to terminate.' % pod_name)
+
+
+def _GetJobSelector(pod_name_prefix, job_name=None):
+ selector = 'name-prefix in (%s)' % pod_name_prefix
+ if job_name:
+ selector += ',job in (%s)' % job_name
+ return selector
+
+
+def WaitForCompletion(pod_name_prefix, job_name='worker', timeout=2*60*60):
+ """Waits until jobs matching pod_name and job_name are terminated.
+
+ Args:
+ pod_name_prefix: value of 'name-prefix' selector.
+ job_name: value of 'job' selector.
+ timeout: how long to wait for jobs to terminate before timing out.
+
+ Returns:
+ True if jobs terminated with success, False otherwise.
+
+ Raises:
+ TimeoutError: if jobs haven't terminated after timeout.
+ ValueError: if we couldn't find jobs matching pod_name and job_name.
+ """
+ # Jsonpath that selects comma-separated exit codes (followed by extra comma
+ # at the end).
+ # If a job doesn't have an exit code yet, empty string will be returned
+ # instead. For ex. output for 2 jobs where one is missing an exit code
+ # and the other one has an exit code of 0 would look like: ,0,
+ last_state_query = (
+ 'jsonpath=\'{range .items[*]}'
+ '{.status.containerStatuses[?(@.lastState.terminated)]'
+ '.lastState.terminated.exitCode},{end}\'')
+ status_command = [
+ _KUBECTL, 'get', '-o', last_state_query,
+ 'pods', '-l', _GetJobSelector(pod_name_prefix, job_name)
+ ]
+
+ exit_codes = []
+ start_time = time.time()
+ while time.time() - start_time < timeout:
+ # Output of check_output is a string that starts and ends with '.
+ output = subprocess.check_output(
+ status_command, universal_newlines=True).strip('\'')
+ logging.debug('Pod status: %s', output)
+ if not output:
+ raise ValueError(
+ 'Query did not match any data. Query: %s' % ' '.join(status_command))
+ # Output will end with an extra comma. So, we remove it before splitting.
+ exit_codes = output[:-1].split(',')
+ if '' not in exit_codes: # fetched all exit codes
+ break
+ time.sleep(WAIT_PERIOD_SECONDS)
+
+ if '' in exit_codes:
+ raise TimeoutError(
+ 'Timed out waiting for %s %s jobs to finish.' %
+ (pod_name_prefix, job_name))
+ _PrintLogs(pod_name_prefix, job_name)
+
+ failed_job_count = sum(code != '0' for code in exit_codes)
+ if failed_job_count > 0:
+ logging.error('%d out of %d jobs failed. Exit codes: %s',
+ failed_job_count, len(exit_codes), ','.join(exit_codes))
+ return False
+ return True
+
+
+def _PrintLogs(pod_name_prefix, job_name):
+ """Prints pod logs.
+
+ If a pod has been restarted, prints logs from previous run. Otherwise,
+ prints the logs from current run. We print logs for pods selected
+ based on pod_name_prefix and job_name.
+
+ Args:
+ pod_name_prefix: value of 'name-prefix' selector.
+ job_name: value of 'job' selector.
+ """
+ for pod_name in _GetPodNames(pod_name_prefix, job_name):
+ try:
+ # Get previous logs.
+ logs_command = [_KUBECTL, 'logs', '-p', pod_name]
+ logging.info('Command to get logs: %s', ' '.join(logs_command))
+ output = subprocess.check_output(logs_command, universal_newlines=True)
+ except subprocess.CalledProcessError:
+ # We couldn't get previous logs, so we will try to get current logs.
+ logs_command = [_KUBECTL, 'logs', pod_name]
+ logging.info('Command to get logs: %s', ' '.join(logs_command))
+ output = subprocess.check_output(logs_command, universal_newlines=True)
+ print('%s logs:' % pod_name)
+ print(output)
diff --git a/tensorflow/tools/dist_test/scripts/kubectl_util_test.py b/tensorflow/tools/dist_test/scripts/kubectl_util_test.py
new file mode 100644
index 0000000000..e2ddf79202
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/kubectl_util_test.py
@@ -0,0 +1,72 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.tools.dist_test.scripts.kubectl_util."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import subprocess
+
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+from tensorflow.tools.dist_test.scripts import kubectl_util
+
+
+kubectl_util.WAIT_PERIOD_SECONDS = 1
+
+
+class KubectlUtilTest(googletest.TestCase):
+
+ @test.mock.patch.object(subprocess, 'check_output')
+ @test.mock.patch.object(subprocess, 'check_call')
+ def testCreatePods(self, mock_check_call, mock_check_output):
+ mock_check_output.return_value = 'nonempty'
+ kubectl_util.CreatePods('test_pod', 'test.yaml')
+ mock_check_call.assert_called_once_with(
+ ['kubectl', 'create', '--filename=test.yaml'])
+ mock_check_output.assert_called_once_with(
+ ['kubectl', 'get', 'pods', '-o', 'name', '-l',
+ 'name-prefix in (test_pod)'], universal_newlines=True)
+
+ @test.mock.patch.object(subprocess, 'check_output')
+ @test.mock.patch.object(subprocess, 'call')
+ def testDeletePods(self, mock_check_call, mock_check_output):
+ mock_check_output.return_value = ''
+ kubectl_util.DeletePods('test_pod', 'test.yaml')
+ mock_check_call.assert_called_once_with(
+ ['kubectl', 'delete', '--filename=test.yaml'])
+ mock_check_output.assert_called_once_with(
+ ['kubectl', 'get', 'pods', '-o', 'name', '-l',
+ 'name-prefix in (test_pod)'], universal_newlines=True)
+
+ @test.mock.patch.object(subprocess, 'check_output')
+ def testWaitForCompletion(self, mock_check_output):
+ # Test success
+ mock_check_output.return_value = '\'0,0,\''
+ self.assertTrue(kubectl_util.WaitForCompletion('test_pod'))
+
+ # Test failure
+ mock_check_output.return_value = '\'0,1,\''
+ self.assertFalse(kubectl_util.WaitForCompletion('test_pod'))
+
+ # Test timeout
+ with self.assertRaises(kubectl_util.TimeoutError):
+ mock_check_output.return_value = '\'0,,\''
+ kubectl_util.WaitForCompletion('test_pod', timeout=5)
+
+
+if __name__ == '__main__':
+ googletest.main()