aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Eric Liu <ioeric@google.com>2017-07-17 11:57:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-17 12:02:27 -0700
commit09e9b157787f2d03ae864569de33e62424657450 (patch)
tree699ca71ae84925c7e609a50b1432e70dc072def5
parent139a866a82fd232993cb12bb9d054ba5f9ade5ac (diff)
Add a gRPC client for profiling TPU (contrib/tpu/profiler/)
This contains a gRPC client that starts/stops tracing and processes/stores the result trace data into a TensorBoard log directory. This also exposes trace_events proto classes via tf.contrib.tpu.profiler public API so that TensorBoard's profile plugin can process and visualize the profile. PiperOrigin-RevId: 162247333
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake3
-rw-r--r--tensorflow/contrib/tpu/BUILD11
-rw-r--r--tensorflow/contrib/tpu/__init__.py6
-rw-r--r--tensorflow/contrib/tpu/profiler/BUILD36
-rw-r--r--tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc105
-rw-r--r--tensorflow/contrib/tpu/profiler/tpu_profiler.proto32
-rw-r--r--tensorflow/contrib/tpu/profiler/trace_events.proto59
-rw-r--r--tensorflow/contrib/tpu/python/profiler/__init__.py30
8 files changed, 281 insertions, 1 deletions
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 75709df55b..8701a584e8 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -128,6 +128,7 @@ file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir}
"${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto"
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/proto/*.proto"
"${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto"
+ "${tensorflow_source_dir}/tensorflow/contrib/tpu/profiler/*.proto"
"${tensorflow_source_dir}/tensorflow/contrib/training/*.proto"
)
RELATIVE_PROTOBUF_GENERATE_PYTHON(
@@ -541,8 +542,10 @@ add_python_module("tensorflow/contrib/timeseries/python/timeseries")
add_python_module("tensorflow/contrib/timeseries/python/timeseries/state_space_models")
add_python_module("tensorflow/contrib/tpu")
add_python_module("tensorflow/contrib/tpu/ops")
+add_python_module("tensorflow/contrib/tpu/profiler")
add_python_module("tensorflow/contrib/tpu/python")
add_python_module("tensorflow/contrib/tpu/python/ops")
+add_python_module("tensorflow/contrib/tpu/python/profiler")
add_python_module("tensorflow/contrib/tpu/python/tpu")
add_python_module("tensorflow/contrib/training")
add_python_module("tensorflow/contrib/training/python")
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 210bfc766e..c99a4d0475 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -93,6 +93,15 @@ tf_gen_op_wrapper_py(
],
)
+py_library(
+ name = "profiler",
+ srcs = ["python/profiler/__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/tpu/profiler:trace_events_proto_py",
+ ],
+)
+
tf_custom_op_py_library(
name = "tpu_py",
srcs = glob(["python/ops/*.py"]) + ["__init__.py"],
@@ -102,6 +111,7 @@ tf_custom_op_py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":profiler",
":tpu_ops",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:client_testlib",
@@ -146,6 +156,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":profiler",
":tpu_function",
":tpu_py",
":training_loop",
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index bfd7887c51..1abd55b56d 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -20,9 +20,13 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
+from tensorflow.contrib.tpu.python import profiler
from tensorflow.contrib.tpu.python.ops.tpu_ops import *
from tensorflow.contrib.tpu.python.tpu import *
# pylint: enable=wildcard-import,unused-import
from tensorflow.python.util.all_util import remove_undocumented
-remove_undocumented(__name__)
+
+_allowed_symbols = ['profiler']
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD
new file mode 100644
index 0000000000..b806a94b1b
--- /dev/null
+++ b/tensorflow/contrib/tpu/profiler/BUILD
@@ -0,0 +1,36 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_cc")
+
+tf_proto_library_cc(
+ name = "tpu_profiler_proto",
+ srcs = ["tpu_profiler.proto"],
+ has_services = 1,
+ cc_api_version = 2,
+ cc_grpc_version = 1,
+ protodeps = [
+ "//tensorflow/core:protos_all",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_binary(
+ name = "capture_tpu_profile",
+ srcs = ["capture_tpu_profile.cc"],
+ visibility = ["//tensorflow/contrib/tpu/profiler:__subpackages__"],
+ deps = [
+ ":tpu_profiler_proto_cc",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_util",
+ "@grpc//:grpc++_unsecure",
+ ],
+)
+
+tf_proto_library(
+ name = "trace_events_proto",
+ srcs = ["trace_events.proto"],
+ cc_api_version = 2,
+ visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
new file mode 100644
index 0000000000..b575128159
--- /dev/null
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -0,0 +1,105 @@
+/* Copyright 2017 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Usage: capture_tpu_profile --service_addr="localhost:8466" --logdir=/tmp/log
+//
+// Initiates a TPU profiling on the TPUProfiler service at service_addr,
+// receives and dumps the profile data to a tensorboard log directory.
+
+#include "grpc++/grpc++.h"
+
+#include <cstdio>
+#include <ctime>
+#include <vector>
+
+#include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace tpu {
+namespace {
+
+using ::tensorflow::TPUProfiler;
+
+using ::grpc::ClientContext;
+using ::tensorflow::io::JoinPath;
+using ::tensorflow::Env;
+using ::tensorflow::WriteStringToFile;
+
+constexpr char kProfilePluginDirectory[] = "plugins/profile/";
+constexpr char kTraceFileName[] = "trace";
+
+tensorflow::string GetCurrentTimeStampAsString() {
+ char s[128];
+ std::time_t t = std::time(nullptr);
+ CHECK_NE(std::strftime(s, sizeof(s), "%F_%T", std::localtime(&t)), 0);
+ return s;
+}
+
+// The trace will be stored in <logdir>/plugins/profile/<timestamp>/trace.
+void DumpTraceToLogDirectory(const tensorflow::string& logdir,
+ tensorflow::StringPiece trace) {
+ tensorflow::string run = GetCurrentTimeStampAsString();
+ tensorflow::string run_dir = JoinPath(logdir, kProfilePluginDirectory, run);
+ TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(run_dir));
+ tensorflow::string path = JoinPath(run_dir, kTraceFileName);
+ TF_CHECK_OK(WriteStringToFile(tensorflow::Env::Default(), path, trace));
+ LOG(INFO) << "Dumped trace data to " << path;
+}
+
+ProfileResponse Profile(const tensorflow::string& service_addr) {
+ ProfileRequest request;
+ ProfileResponse response;
+ ClientContext context;
+ std::unique_ptr<TPUProfiler::Stub> stub =
+ TPUProfiler::NewStub(::grpc::CreateChannel(
+ service_addr, ::grpc::InsecureChannelCredentials()));
+ TF_CHECK_OK(FromGrpcStatus(stub->Profile(&context, request, &response)));
+ return response;
+}
+
+} // namespace
+} // namespace tpu
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ tensorflow::string FLAGS_service_addr;
+ tensorflow::string FLAGS_logdir;
+ std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("service_addr", &FLAGS_service_addr,
+ "Address of TPU profiler service e.g. localhost:8466"),
+ tensorflow::Flag("logdir", &FLAGS_logdir,
+ "Path of TensorBoard log directory e.g. /tmp/tb_log"),
+ };
+ tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_ok || FLAGS_service_addr.empty() || FLAGS_logdir.empty()) {
+ std::printf("%s", usage.c_str());
+ return 2;
+ }
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+ tensorflow::ProfileResponse response =
+ tensorflow::tpu::Profile(FLAGS_service_addr);
+ // Ignore computation_graph for now.
+ tensorflow::tpu::DumpTraceToLogDirectory(FLAGS_logdir,
+ response.encoded_trace());
+}
diff --git a/tensorflow/contrib/tpu/profiler/tpu_profiler.proto b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
new file mode 100644
index 0000000000..d12eccbf4b
--- /dev/null
+++ b/tensorflow/contrib/tpu/profiler/tpu_profiler.proto
@@ -0,0 +1,32 @@
+syntax = "proto3";
+package tensorflow;
+
+import "tensorflow/core/framework/graph.proto";
+
+// The TPUProfiler service retrieves performance information about
+// the programs running on connected TPUs over a period of time.
+service TPUProfiler {
+ // Starts a profiling session, blocks until it completes, and returns data.
+ rpc Profile(ProfileRequest) returns (ProfileResponse) {
+ }
+}
+
+message ProfileRequest {
+ // In future, the caller will be able to customize when profiling starts and
+ // stops. For now, it always collects 10 seconds worth of data.
+
+ // In future, the caller will indicate which TF session is being profiled, and
+ // only data relating to that program will be returned. For now, we assume
+ // all activity during the profiling period is relevant.
+}
+
+message ProfileResponse {
+ uint64 xprof_response_size = 1; // Placeholder: return something meaningful.
+ // Graphs of programs executed on TPUs during the profiling period.
+ repeated GraphDef computation_graph = 2;
+
+ // Encoded Trace proto message that contains metadata about the trace captured
+ // during the profiling period. Describes the devices and resources that
+ // 'trace_events' refers to.
+ bytes encoded_trace = 3;
+}
diff --git a/tensorflow/contrib/tpu/profiler/trace_events.proto b/tensorflow/contrib/tpu/profiler/trace_events.proto
new file mode 100644
index 0000000000..0ab553ca96
--- /dev/null
+++ b/tensorflow/contrib/tpu/profiler/trace_events.proto
@@ -0,0 +1,59 @@
+syntax = "proto3";
+
+package tensorflow.tpu;
+
+// A 'Trace' contains metadata for the individual traces of a system.
+message Trace {
+ // The devices that this trace has information about. Maps from device_id to
+ // more data about the specific device.
+ map<uint64, Device> devices = 1;
+
+ // All trace events capturing in the profiling period.
+ repeated TraceEvent trace_events = 4;
+}
+
+// A 'device' is a physical entity in the system and is comprised of several
+// resources.
+message Device {
+ // The name of the device.
+ string name = 1;
+
+ // The id of this device, unique in a single trace.
+ uint64 device_id = 2;
+
+ // The resources on this device, keyed by resource_id;
+ map<uint64, Resource> resources = 3;
+}
+
+// A 'resource' generally is a specific computation component on a device. These
+// can range from threads on CPUs to specific arithmetic units on hardware
+// devices.
+message Resource {
+ // The name of the resource.
+ string name = 1;
+
+ // The id of the resource. Unique within a device.
+ uint64 resource_id = 2;
+}
+
+message TraceEvent {
+ // The id of the device that this event occurred on. The full dataset should
+ // have this device present in the Trace object.
+ uint64 device_id = 1;
+
+ // The id of the resource that this event occurred on. The full dataset should
+ // have this resource present in the Device object of the Trace object. A
+ // resource_id is unique on a specific device, but not necessarily within the
+ // trace.
+ uint64 resource_id = 2;
+
+ // The name of this trace event.
+ string name = 3;
+
+ // The timestamp that this event occurred at (in picos since tracing started).
+ uint64 timestamp_ps = 9;
+
+ // The duration of the event in picoseconds if applicable.
+ // Events without duration are called instant events.
+ uint64 duration_ps = 10;
+}
diff --git a/tensorflow/contrib/tpu/python/profiler/__init__.py b/tensorflow/contrib/tpu/python/profiler/__init__.py
new file mode 100644
index 0000000000..bde13f0527
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/profiler/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Classes for TPU trace events."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=wildcard-import,unused-import
+from tensorflow.contrib.tpu.profiler.trace_events_pb2 import *
+# pylint: enable=wildcard-import,unused-import
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = ['Trace', 'Resource', 'Device', 'TraceEvent']
+
+remove_undocumented(__name__, _allowed_symbols)