aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-06 16:46:54 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-06 16:53:12 -0800
commit6e99d56489b4e6c3176fa1199d4270b6439a22fe (patch)
tree18f430d5c93103ca49334a25c231c07b6371ad1a
parent7efc16ed02121b92993b3417805cea652bab3c92 (diff)
Add metadata for gathering information about host compute transfers while compiling XLA.
PiperOrigin-RevId: 188102740
-rw-r--r--tensorflow/compiler/tf2xla/BUILD10
-rw-r--r--tensorflow/compiler/tf2xla/host_compute_metadata.proto38
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc63
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h24
4 files changed, 135 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index fb82c2601c..eb20ca501c 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -58,6 +58,15 @@ xla_proto_library(
],
)
+# Proto library describing the host-compute transfer metadata (sends/recvs)
+# gathered while compiling XLA; public so runtimes outside tf2xla can consume it.
+xla_proto_library(
+ name = "host_compute_metadata_proto",
+ srcs = ["host_compute_metadata.proto"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
cc_library(
name = "tf2xla",
srcs = ["tf2xla.cc"],
@@ -149,6 +158,7 @@ cc_library(
":common",
":dump_graph",
":functionalize_control_flow",
+ ":host_compute_metadata_proto",
":sharding_util",
":tf2xla_util",
"//tensorflow/compiler/tf2xla/lib:util",
diff --git a/tensorflow/compiler/tf2xla/host_compute_metadata.proto b/tensorflow/compiler/tf2xla/host_compute_metadata.proto
new file mode 100644
index 0000000000..43ab371a21
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/host_compute_metadata.proto
@@ -0,0 +1,38 @@
+syntax = "proto3";
+
+package tensorflow.tf2xla;
+option cc_enable_arenas = true;
+option java_outer_classname = "Tf2XlaProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.tf2xla";
+
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+
+// TensorMetadata indicates the type and shape of a Tensor that is
+// part of a host compute transfer.
+message TensorMetadata {
+ // TensorFlow dtype of the transferred tensor.
+ DataType type = 1;
+ // Static shape of the transferred tensor.
+ TensorShapeProto shape = 2;
+}
+
+// HostTransferMetadata describes a transfer either from host to device
+// or device to host. It has a key that is unique to the computation,
+// and metadata about the list of tensors being transferred.
+message HostTransferMetadata {
+ // The key used to identify this transfer.
+ string key = 1;
+
+ // For each Tensor being transferred, its type and shape.
+ repeated TensorMetadata metadata = 2;
+}
+
+// HostComputeMetadata describes all the sends and recvs
+// from all host compute transfer ops in a computation.
+message HostComputeMetadata {
+ // Metadata about each device_to_host transfer
+ repeated HostTransferMetadata device_to_host = 1;
+
+ // Metadata about each host_to_device transfer
+ repeated HostTransferMetadata host_to_device = 2;
+}
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 5ec05c4121..0dc5118c9c 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -674,6 +674,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
VLOG(2) << "XLA output shape: "
<< xla::ShapeUtil::HumanString(result->xla_output_shape);
+ // Copy the host transfer metadata to the result.
+ for (const auto& send : host_compute_sends_) {
+ *result->host_compute_metadata.add_device_to_host() = send.second;
+ }
+ for (const auto& recv : host_compute_recvs_) {
+ *result->host_compute_metadata.add_host_to_device() = recv.second;
+ }
+
// Tensorflow expects a major-to-minor order of results.
xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
@@ -708,4 +716,59 @@ Status XlaCompiler::GetChannelHandle(const string& key,
return Status::OK();
}
+namespace {
+
+// Fills in `transfer` with `key` plus one TensorMetadata entry per element of
+// the parallel vectors `types`/`shapes`. The two vectors must be the same
+// length; this is a programming error, hence CHECK rather than Status.
+void SetTransfer(const string& key, const std::vector<DataType>& types,
+                 const std::vector<TensorShape>& shapes,
+                 tf2xla::HostTransferMetadata* transfer) {
+  transfer->set_key(key);
+  // CHECK_EQ (vs CHECK(a == b)) logs both operand values on failure.
+  CHECK_EQ(types.size(), shapes.size());
+  // size_t index avoids the signed/unsigned comparison against size().
+  for (size_t i = 0; i < types.size(); ++i) {
+    tf2xla::TensorMetadata* metadata = transfer->add_metadata();
+    metadata->set_type(types[i]);
+    shapes[i].AsProto(metadata->mutable_shape());
+  }
+}
+
+}  // namespace
+
+// Records the types/shapes of the tensors that the device-to-host transfer
+// identified by `key` will carry. Returns InvalidArgument if metadata for
+// `key` was already registered; on success the entry is stored in
+// host_compute_sends_ and later copied into CompilationResult.
+Status XlaCompiler::SetDeviceToHostMetadata(
+ const string& key, const std::vector<DataType>& types,
+ const std::vector<TensorShape>& shapes) {
+ if (host_compute_sends_.find(key) != host_compute_sends_.end()) {
+ return errors::InvalidArgument(
+ "Duplicate calls to SetDeviceToHostMetadata with key ", key);
+ }
+ // operator[] default-constructs the map entry, then SetTransfer fills it.
+ tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key];
+ SetTransfer(key, types, shapes, &transfer);
+ return Status::OK();
+}
+
+// Looks up the shapes previously registered via SetDeviceToHostMetadata for
+// `key` and writes them into `*shapes` (cleared first). Returns
+// InvalidArgument if nothing was registered under `key`.
+Status XlaCompiler::GetDeviceToHostShapes(
+ const string& key, std::vector<TensorShape>* shapes) const {
+ const auto iter = host_compute_sends_.find(key);
+ if (iter == host_compute_sends_.end()) {
+ return errors::InvalidArgument(
+ "No host compute send shapes registered for key ", key);
+ }
+ shapes->clear();
+ for (int i = 0; i < iter->second.metadata_size(); ++i) {
+ // Reconstruct a TensorShape from the stored TensorShapeProto.
+ TensorShape shape(iter->second.metadata(i).shape());
+ shapes->push_back(shape);
+ }
+ return Status::OK();
+}
+
+// Records the types/shapes of the tensors that the host-to-device transfer
+// identified by `key` will carry. Returns InvalidArgument if metadata for
+// `key` was already registered.
+Status XlaCompiler::SetHostToDeviceMetadata(
+    const string& key, const std::vector<DataType>& types,
+    const std::vector<TensorShape>& shapes) {
+  // Fix: the duplicate-key guard must compare against host_compute_recvs_'s
+  // own end(); the original compared a host_compute_recvs_ iterator against
+  // host_compute_sends_.end(), which is undefined behavior and would not
+  // detect duplicates.
+  if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) {
+    return errors::InvalidArgument(
+        "Duplicate calls to SetHostToDeviceMetadata with key ", key);
+  }
+  tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key];
+  SetTransfer(key, types, shapes, &transfer);
+  return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index c4449bc4be..a70d2637e0 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
+#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device.h"
@@ -216,6 +217,10 @@ class XlaCompiler {
// containing both constant and non-constant results.
std::vector<OutputDescription> outputs;
+ // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
+ // matching RecvAtHost/SendFromHost Ops in the outer graph.
+ tf2xla::HostComputeMetadata host_compute_metadata;
+
// Resources whose values were updated by the computation, ordered
// by return value position. Resource updates follow the non-constant
// results in the outputs of XLA computation.
@@ -296,6 +301,22 @@ class XlaCompiler {
// same XlaCompiler.
Status GetChannelHandle(const string& key, xla::ChannelHandle* channel);
+ // Sets the shapes and types for the device to host transfer associated with
+ // 'key'.
+ Status SetDeviceToHostMetadata(const string& key,
+ const std::vector<DataType>& types,
+ const std::vector<TensorShape>& shapes);
+
+ // Gets the shapes the device to host transfer associated with 'key'.
+ Status GetDeviceToHostShapes(const string& key,
+ std::vector<TensorShape>* shapes) const;
+
+ // Sets the shapes and types for the host to device transfer associated with
+ // 'key'.
+ Status SetHostToDeviceMetadata(const string& key,
+ const std::vector<DataType>& types,
+ const std::vector<TensorShape>& shapes);
+
const Options& options() const { return options_; }
xla::Client* client() const { return options_.client; }
FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; }
@@ -359,6 +380,9 @@ class XlaCompiler {
std::unordered_map<string, xla::ChannelHandle> channels_;
+ std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_sends_;
+ std::unordered_map<string, tf2xla::HostTransferMetadata> host_compute_recvs_;
+
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};