aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/distributed_runtime/BUILD23
-rw-r--r--tensorflow/core/distributed_runtime/tensor_coding.cc221
-rw-r--r--tensorflow/core/distributed_runtime/tensor_coding.h85
-rw-r--r--tensorflow/core/distributed_runtime/tensor_coding_test.cc186
-rw-r--r--tensorflow/core/platform/default/protobuf.h1
5 files changed, 515 insertions, 1 deletions
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 51bb908213..244ba8bb7e 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -51,9 +51,14 @@ cc_library(
cc_library(
name = "worker_interface",
- hdrs = ["worker_interface.h"],
+ srcs = ["tensor_coding.cc"],
+ hdrs = [
+ "tensor_coding.h",
+ "worker_interface.h",
+ ],
deps = [
":call_options",
+ "//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:worker_proto_cc",
],
@@ -79,6 +84,22 @@ cc_test(
],
)
+cc_test(
+ name = "tensor_coding_test",
+ size = "small",
+ srcs = ["tensor_coding_test.cc"],
+ linkstatic = 1,
+ deps = [
+ ":worker_interface",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensor_testutil",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
cc_library(
name = "worker_cache",
hdrs = ["worker_cache.h"],
diff --git a/tensorflow/core/distributed_runtime/tensor_coding.cc b/tensorflow/core/distributed_runtime/tensor_coding.cc
new file mode 100644
index 0000000000..72399c9b11
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/tensor_coding.cc
@@ -0,0 +1,221 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/tensor_coding.h"
+
+namespace tensorflow {
+
+TensorResponse::TensorResponse(Allocator* allocator) : allocator_(allocator) {}
+
+Status TensorResponse::ParseFrom(Source* source) {
+ if (already_used_) {
+ Clear();
+ }
+ already_used_ = true;
+ if (ParseFast(source)) return Status::OK();
+ meta_.Clear();
+ if (ParseSlow(source)) return Status::OK();
+ return errors::InvalidArgument("Cannot parse tensor from response");
+}
+
+// Define some helper routines for decoding protocol buffer wire format data
+namespace {
+// We only need some of the wiretype values for this code
+enum WireType {
+ WIRETYPE_VARINT = 0,
+ WIRETYPE_LENGTH_DELIMITED = 2,
+};
+inline int GetTagFieldNumber(uint32 tag) { return tag >> 3; }
+inline WireType GetTagWireType(uint32 tag) {
+ return static_cast<WireType>(tag & 0x7);
+}
+
+bool ReadVarintSizeAsInt(protobuf::io::CodedInputStream* input, int* result) {
+ uint64 v;
+ if (input->ReadVarint64(&v) && v <= static_cast<uint64>(INT_MAX)) {
+ *result = static_cast<int>(v);
+ return true;
+ } else {
+ return false;
+ }
+}
+
+bool ReadNestedMessage(protobuf::io::CodedInputStream* input,
+ protobuf::Message* value) {
+ int length;
+ if (!ReadVarintSizeAsInt(input, &length)) return false;
+ std::pair<protobuf::io::CodedInputStream::Limit, int> p =
+ input->IncrementRecursionDepthAndPushLimit(length);
+ if (p.second < 0 || !value->MergePartialFromCodedStream(input)) return false;
+ // Make sure that parsing stopped when the limit was hit, not at an endgroup
+ // tag.
+ return input->DecrementRecursionDepthAndPopLimit(p.first);
+}
+
+} // namespace
+
+bool TensorResponse::ParseTensorSubmessage(
+ protobuf::io::CodedInputStream* input, TensorProto* tensor_meta) {
+ bool seen_tensor_content = false;
+ while (true) {
+ auto p = input->ReadTagWithCutoff(127);
+ int tag = GetTagFieldNumber(p.first);
+ WireType wt = GetTagWireType(p.first);
+ if (!p.second) {
+ bool ok = (tag == 0);
+ if (ok && !seen_tensor_content) {
+ // No tensor content: could be because it's a zero-length tensor
+ TensorShape shape(tensor_meta->tensor_shape());
+ Tensor t(allocator_, tensor_meta->dtype(), shape);
+ tensor_ = std::move(t);
+ }
+ return ok;
+ }
+ switch (tag) {
+ case TensorProto::kDtypeFieldNumber: {
+ uint32 v;
+ if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false;
+ if (seen_tensor_content) return false;
+ tensor_meta->set_dtype(static_cast<DataType>(static_cast<int>(v)));
+ if (!DataTypeCanUseMemcpy(tensor_meta->dtype())) return false;
+ break;
+ }
+ case TensorProto::kTensorShapeFieldNumber: {
+ if ((wt != WIRETYPE_LENGTH_DELIMITED) ||
+ !ReadNestedMessage(input, tensor_meta->mutable_tensor_shape()))
+ return false;
+ if (seen_tensor_content) return false;
+ break;
+ }
+ case TensorProto::kVersionNumberFieldNumber: {
+ uint32 v;
+ if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false;
+ if (seen_tensor_content) return false;
+ tensor_meta->set_version_number(static_cast<int32>(v));
+ break;
+ }
+ case TensorProto::kTensorContentFieldNumber: {
+ // If we haven't seen the dtype and tensor_shape data first, we can't
+ // deal with this in the fast path.
+ if (seen_tensor_content) return false;
+ if (wt != WIRETYPE_LENGTH_DELIMITED ||
+ !tensor_meta->has_tensor_shape()) {
+ return false;
+ }
+ int num_bytes;
+ if (!ReadVarintSizeAsInt(input, &num_bytes)) return false;
+ seen_tensor_content = true;
+ TensorShape shape(tensor_meta->tensor_shape());
+ Tensor t(allocator_, tensor_meta->dtype(), shape);
+ StringPiece buf = t.tensor_data();
+ if (num_bytes != buf.size()) return false;
+ // TODO(jeff,sanjay): Figure out a way to avoid this copy if
+ // the underlying ZeroCopyInputStream data is properly aligned
+ // and compatible with what allocator_ wants.
+ if (!input->ReadRaw(const_cast<char*>(buf.data()), num_bytes))
+ return false;
+ tensor_ = std::move(t);
+ break;
+ }
+ default: {
+ // Some other tag our fast path code is not prepared to handle.
+ // return false.
+ return false;
+ }
+ }
+ }
+}
+
+bool TensorResponse::ParseFast(Source* source) {
+ protobuf::io::CodedInputStream input(source->contents());
+ while (true) {
+ auto p = input.ReadTagWithCutoff(127);
+ int tag = GetTagFieldNumber(p.first);
+ WireType wt = GetTagWireType(p.first);
+ if (!p.second) {
+ return (tag == 0);
+ }
+ switch (tag) {
+ case RecvTensorResponse::kTensorFieldNumber: {
+ if (wt != WIRETYPE_LENGTH_DELIMITED) return false;
+
+ int length;
+ if (!ReadVarintSizeAsInt(&input, &length)) return false;
+ std::pair<protobuf::io::CodedInputStream::Limit, int> p =
+ input.IncrementRecursionDepthAndPushLimit(length);
+ if (p.second < 0 ||
+ !ParseTensorSubmessage(&input, meta_.mutable_tensor())) {
+ return false;
+ }
+ if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
+ return false;
+ }
+ break;
+ }
+ case RecvTensorResponse::kIsDeadFieldNumber: {
+ uint32 v;
+ if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
+ meta_.set_is_dead((v != 0) ? true : false);
+ break;
+ }
+ case RecvTensorResponse::kSendStartMicrosFieldNumber: {
+ uint64 v;
+ if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) return false;
+ meta_.set_send_start_micros(static_cast<int64>(v));
+ break;
+ }
+ case RecvTensorResponse::kTransportOptionsFieldNumber: {
+ if ((wt != WIRETYPE_LENGTH_DELIMITED) ||
+ !ReadNestedMessage(&input, meta_.mutable_transport_options()))
+ return false;
+ break;
+ }
+ default: {
+ // Unknown tag, so don't handle we can't handle on the fast path
+ return false;
+ }
+ }
+ }
+
+ return false;
+}
+
+bool TensorResponse::ParseSlow(Source* source) {
+ if (!meta_.ParseFromZeroCopyStream(source->contents())) {
+ return false;
+ }
+
+ Tensor parsed(meta_.tensor().dtype());
+ if (!parsed.FromProto(allocator_, meta_.tensor())) {
+ return false;
+ }
+ tensor_ = std::move(parsed);
+
+ // Reduce memory usage for big tensors.
+ {
+ TensorProto empty;
+ meta_.mutable_tensor()->Swap(&empty);
+ }
+ meta_.clear_tensor();
+
+ return true;
+}
+
+void TensorResponse::Clear() {
+ meta_.Clear();
+ tensor_ = Tensor();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/tensor_coding.h b/tensorflow/core/distributed_runtime/tensor_coding.h
new file mode 100644
index 0000000000..e193b0776d
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/tensor_coding.h
@@ -0,0 +1,85 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+
+namespace tensorflow {
+
+class Allocator;
+class TensorProto;
+
+// TensorResponse can be used as the destination of an RPC that returns
+// a RecvTensorResponse. It efficiently decodes the incoming data
+// into Tensor contents as well as associated metadata.
+class TensorResponse {
+ public:
+ explicit TensorResponse(Allocator* allocator);
+
+ // Source provides a way for a particular RPC implementation to provide
+ // received data to ParseFrom.
+ class Source {
+ public:
+ // Return the stream that contains the data to be parsed.
+ // Note that this method might be invoked more than once if
+ // ParseFrom needs to fall back to a more expensive parsing method.
+ // Every call must return a stream pointing at the beginning of
+ // the serialized RecvTensorResponse.
+ //
+ // Note that a subsequent call to contents() invalidates previous
+ // results of contents().
+ //
+ // Ownership of the returned stream is retained by the Source and
+ // should not be deleted by the caller.
+ virtual ::tensorflow::protobuf::io::ZeroCopyInputStream* contents() = 0;
+ };
+
+ // Parse the RecvTensorResponse encoded in the data yielded by
+ // source->contents() into *this.
+ Status ParseFrom(Source* source);
+
+ // Return a reference to the parsed tensor. The tensor will remain
+ // live only until *this is destroyed or modified.
+ const Tensor& tensor() const { return tensor_; }
+
+ // Return a reference to the parsed tensor metadata (no contents).
+ // The result will remain live only until *this is destroyed or
+ // modified.
+ const RecvTensorResponse& metadata() const { return meta_; }
+
+ // Clear contents of *this.
+ void Clear();
+
+ private:
+ bool ParseTensorSubmessage(protobuf::io::CodedInputStream* input,
+ TensorProto* tensor_meta);
+ bool ParseFast(Source* source);
+ bool ParseSlow(Source* source);
+
+ Allocator* allocator_ = nullptr;
+ bool already_used_ = false;
+ Tensor tensor_;
+ RecvTensorResponse meta_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_
diff --git a/tensorflow/core/distributed_runtime/tensor_coding_test.cc b/tensorflow/core/distributed_runtime/tensor_coding_test.cc
new file mode 100644
index 0000000000..0b1d3b6189
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/tensor_coding_test.cc
@@ -0,0 +1,186 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/tensor_coding.h"
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+
+namespace tensorflow {
+
+class StringSource : public TensorResponse::Source {
+ public:
+ explicit StringSource(const string* s, int block_size)
+ : s_(s), stream_(nullptr), block_size_(block_size) {}
+ virtual ~StringSource() { DeleteStream(); }
+
+ protobuf::io::ZeroCopyInputStream* contents() {
+ DeleteStream();
+ stream_ = new (&space_)
+ protobuf::io::ArrayInputStream(s_->data(), s_->size(), block_size_);
+ return stream_;
+ }
+
+ void DeleteStream() {
+ if (stream_) {
+ stream_->~ArrayInputStream();
+ }
+ }
+
+ private:
+ const string* s_;
+ protobuf::io::ArrayInputStream* stream_;
+ char space_[sizeof(protobuf::io::ArrayInputStream)];
+ int block_size_;
+};
+
+class TensorResponseTest : public ::testing::Test {
+ public:
+ void Validate(const Tensor& src, bool is_dead, bool use_tensor_content) {
+ RecvTensorResponse proto;
+ proto.set_is_dead(is_dead);
+ proto.set_send_start_micros(123456);
+ if (use_tensor_content) {
+ src.AsProtoTensorContent(proto.mutable_tensor());
+ } else {
+ src.AsProtoField(proto.mutable_tensor());
+ }
+ string encoded;
+ proto.AppendToString(&encoded);
+
+ StringSource source(&encoded, 1024);
+
+ TensorResponse response(cpu_allocator());
+ for (int i = 0; i < 2; i++) { // Twice so we exercise reuse of "response"
+ Status s = response.ParseFrom(&source);
+ EXPECT_TRUE(s.ok());
+
+ const RecvTensorResponse& meta = response.metadata();
+ EXPECT_EQ(meta.is_dead(), is_dead);
+ EXPECT_EQ(meta.send_start_micros(), 123456);
+
+ const Tensor& result = response.tensor();
+ EXPECT_EQ(result.dtype(), src.dtype());
+ EXPECT_EQ(result.shape().DebugString(), src.shape().DebugString());
+ EXPECT_EQ(result.DebugString(), src.DebugString());
+ }
+ }
+
+ template <typename T>
+ void DoTest(DataType dt) {
+ gtl::InlinedVector<T, 4> v;
+ LOG(ERROR) << "DT: " << static_cast<int>(dt);
+ for (int elems = 0; elems <= 10000; elems++) {
+ if (elems < 100 || (elems % 1000 == 0)) {
+ Tensor a(dt, TensorShape({1, static_cast<int64>(v.size())}));
+ test::FillValues<T>(&a, v);
+ Validate(a, (elems == 0), true);
+ }
+ v.push_back(static_cast<T>(elems));
+ }
+ }
+ void DoTestForStrings(DataType dt) {
+ gtl::InlinedVector<string, 4> v;
+ LOG(ERROR) << "DT: string";
+ for (int elems = 0; elems <= 10000; elems++) {
+ if (elems < 100 || (elems % 1000 == 0)) {
+ Tensor a(dt, TensorShape({1, static_cast<int64>(v.size())}));
+ test::FillValues<string>(&a, v);
+ Validate(a, (elems == 0), true);
+ }
+ v.push_back(strings::StrCat("This is string ", elems));
+ }
+ }
+};
+
+TEST_F(TensorResponseTest, Simple) {
+ DoTest<float>(DT_FLOAT);
+ DoTest<double>(DT_DOUBLE);
+ DoTest<int32>(DT_INT32);
+ DoTest<uint16>(DT_UINT16);
+ DoTest<uint8>(DT_UINT8);
+ DoTest<int16>(DT_INT16);
+ DoTest<int8>(DT_INT8);
+ DoTest<complex64>(DT_COMPLEX64);
+ DoTest<complex128>(DT_COMPLEX128);
+ DoTest<int64>(DT_INT64);
+ DoTest<bool>(DT_BOOL);
+ DoTest<qint8>(DT_QINT8);
+ DoTest<quint8>(DT_QUINT8);
+ DoTest<qint16>(DT_QINT16);
+ DoTest<quint16>(DT_QUINT16);
+ DoTest<qint32>(DT_QINT32);
+ DoTest<bfloat16>(DT_BFLOAT16);
+ DoTest<Eigen::half>(DT_HALF);
+}
+
+TEST_F(TensorResponseTest, StringTensor) { DoTestForStrings(DT_STRING); }
+
+string MakeFloatTensorTestCase(int num_elems) {
+ std::vector<int8> v(num_elems);
+ for (int i = 0; i < num_elems; i++) {
+ v[i] = i % 10;
+ }
+ Tensor src(DT_INT8, TensorShape({1, static_cast<int64>(v.size())}));
+ test::FillValues<int8>(&src, v);
+
+ RecvTensorResponse proto;
+ proto.set_is_dead(false);
+ proto.set_send_start_micros(123456);
+ src.AsProtoTensorContent(proto.mutable_tensor());
+ string encoded;
+ proto.AppendToString(&encoded);
+ return encoded;
+}
+
+static void BM_TensorResponse(int iters, int arg) {
+ testing::StopTiming();
+ string encoded = MakeFloatTensorTestCase(arg);
+ testing::StartTiming();
+ while (--iters > 0) {
+ TensorResponse response(cpu_allocator());
+ StringSource source(&encoded, -1);
+ Status s = response.ParseFrom(&source);
+ if (iters == 1) {
+ testing::SetLabel(
+ strings::StrCat("Bytes: ", response.tensor().TotalBytes()));
+ }
+ }
+}
+BENCHMARK(BM_TensorResponse)->Arg(0)->Arg(1000)->Arg(100000);
+
+static void BM_TensorViaTensorProto(int iters, int arg) {
+ testing::StopTiming();
+ string encoded = MakeFloatTensorTestCase(arg);
+ testing::StartTiming();
+ while (--iters > 0) {
+ RecvTensorResponse r;
+ r.ParseFromString(encoded);
+ Tensor t;
+ CHECK(t.FromProto(r.tensor()));
+ if (iters == 1) {
+ testing::SetLabel(strings::StrCat("Bytes: ", t.TotalBytes()));
+ }
+ }
+}
+BENCHMARK(BM_TensorViaTensorProto)->Arg(0)->Arg(1000)->Arg(100000);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/default/protobuf.h b/tensorflow/core/platform/default/protobuf.h
index 544f62d1f8..acc804a4ee 100644
--- a/tensorflow/core/platform/default/protobuf.h
+++ b/tensorflow/core/platform/default/protobuf.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
+#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/map.h"
#include "google/protobuf/repeated_field.h"
#include "google/protobuf/text_format.h"