aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-28 09:16:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-28 10:31:06 -0700
commit4f566ce3fd448c8c0c51873385875eff893f8ca6 (patch)
tree87055c08428a5f254c26f8589fa33eff63450d21
parentb8e84f5237a3dcd7bd273ba4537c14eca0e0e667 (diff)
Added new TensorResponse module, which parses a serialized RecvTensorResponse
directly from a protobuf::io::ZeroCopyInputStream into the metadata part of the tensor (stored in a RecvTensorResponse protocol object), and the actual Tensor data, stored directly in a Tensor object. This bypasses the extra copy that was happening through a TensorProto object for the actual tensor data. For large tensors, this is considerably faster. Added a test and a benchmark for this functionality. TensorResponse is the time to go through this new module, while TensorViaTensorProto is the more straightforward path that does an extra copy. For large tensors, this is 2.5X faster. Run on XXXX (40 X 2801 MHz CPUs); 2016-07-27T11:55:08.623471233-07:00 CPU: Intel Ivybridge with HyperThreading (20 cores) dL1:32KB dL2:256KB dL3:25MB Benchmark Time(ns) CPU(ns) Iterations ------------------------------------------------------------- BM_TensorResponse/0 301 301 2322471 Bytes: 0 BM_TensorResponse/1000 401 401 1749405 Bytes: 1000 BM_TensorResponse/98k 5391 5397 100000 Bytes: 100000 BM_TensorViaTensorProto/0 279 279 2531623 Bytes: 0 BM_TensorViaTensorProto/1000 551 551 1000000 Bytes: 1000 BM_TensorViaTensorProto/98k 13692 13706 52313 Bytes: 100000 Co-author=sanjay Change: 128712249
-rw-r--r--tensorflow/core/distributed_runtime/BUILD23
-rw-r--r--tensorflow/core/distributed_runtime/tensor_coding.cc221
-rw-r--r--tensorflow/core/distributed_runtime/tensor_coding.h85
-rw-r--r--tensorflow/core/distributed_runtime/tensor_coding_test.cc186
-rw-r--r--tensorflow/core/platform/default/protobuf.h1
5 files changed, 515 insertions, 1 deletions
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 51bb908213..244ba8bb7e 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -51,9 +51,14 @@ cc_library(
cc_library(
name = "worker_interface",
- hdrs = ["worker_interface.h"],
+ srcs = ["tensor_coding.cc"],
+ hdrs = [
+ "tensor_coding.h",
+ "worker_interface.h",
+ ],
deps = [
":call_options",
+ "//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:worker_proto_cc",
],
@@ -79,6 +84,22 @@ cc_test(
],
)
+cc_test(
+ name = "tensor_coding_test",
+ size = "small",
+ srcs = ["tensor_coding_test.cc"],
+ linkstatic = 1,
+ deps = [
+ ":worker_interface",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensor_testutil",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
cc_library(
name = "worker_cache",
hdrs = ["worker_cache.h"],
diff --git a/tensorflow/core/distributed_runtime/tensor_coding.cc b/tensorflow/core/distributed_runtime/tensor_coding.cc
new file mode 100644
index 0000000000..72399c9b11
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/tensor_coding.cc
@@ -0,0 +1,221 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/tensor_coding.h"
+
+namespace tensorflow {
+
+TensorResponse::TensorResponse(Allocator* allocator) : allocator_(allocator) {}
+
+Status TensorResponse::ParseFrom(Source* source) {
+ if (already_used_) {
+ Clear();
+ }
+ already_used_ = true;
+ if (ParseFast(source)) return Status::OK();
+ meta_.Clear();
+ if (ParseSlow(source)) return Status::OK();
+ return errors::InvalidArgument("Cannot parse tensor from response");
+}
+
+// Define some helper routines for decoding protocol buffer wire format data
+namespace {
+// We only need some of the wiretype values for this code
+enum WireType {
+ WIRETYPE_VARINT = 0,
+ WIRETYPE_LENGTH_DELIMITED = 2,
+};
+inline int GetTagFieldNumber(uint32 tag) { return tag >> 3; }
+inline WireType GetTagWireType(uint32 tag) {
+ return static_cast<WireType>(tag & 0x7);
+}
+
+bool ReadVarintSizeAsInt(protobuf::io::CodedInputStream* input, int* result) {
+ uint64 v;
+ if (input->ReadVarint64(&v) && v <= static_cast<uint64>(INT_MAX)) {
+ *result = static_cast<int>(v);
+ return true;
+ } else {
+ return false;
+ }
+}
+
+bool ReadNestedMessage(protobuf::io::CodedInputStream* input,
+ protobuf::Message* value) {
+ int length;
+ if (!ReadVarintSizeAsInt(input, &length)) return false;
+ std::pair<protobuf::io::CodedInputStream::Limit, int> p =
+ input->IncrementRecursionDepthAndPushLimit(length);
+ if (p.second < 0 || !value->MergePartialFromCodedStream(input)) return false;
+ // Make sure that parsing stopped when the limit was hit, not at an endgroup
+ // tag.
+ return input->DecrementRecursionDepthAndPopLimit(p.first);
+}
+
+} // namespace
+
+bool TensorResponse::ParseTensorSubmessage(
+ protobuf::io::CodedInputStream* input, TensorProto* tensor_meta) {
+ bool seen_tensor_content = false;
+ while (true) {
+ auto p = input->ReadTagWithCutoff(127);
+ int tag = GetTagFieldNumber(p.first);
+ WireType wt = GetTagWireType(p.first);
+ if (!p.second) {
+ bool ok = (tag == 0);
+ if (ok && !seen_tensor_content) {
+ // No tensor content: could be because it's a zero-length tensor
+ TensorShape shape(tensor_meta->tensor_shape());
+ Tensor t(allocator_, tensor_meta->dtype(), shape);
+ tensor_ = std::move(t);
+ }
+ return ok;
+ }
+ switch (tag) {
+ case TensorProto::kDtypeFieldNumber: {
+ uint32 v;
+ if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false;
+ if (seen_tensor_content) return false;
+ tensor_meta->set_dtype(static_cast<DataType>(static_cast<int>(v)));
+ if (!DataTypeCanUseMemcpy(tensor_meta->dtype())) return false;
+ break;
+ }
+ case TensorProto::kTensorShapeFieldNumber: {
+ if ((wt != WIRETYPE_LENGTH_DELIMITED) ||
+ !ReadNestedMessage(input, tensor_meta->mutable_tensor_shape()))
+ return false;
+ if (seen_tensor_content) return false;
+ break;
+ }
+ case TensorProto::kVersionNumberFieldNumber: {
+ uint32 v;
+ if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false;
+ if (seen_tensor_content) return false;
+ tensor_meta->set_version_number(static_cast<int32>(v));
+ break;
+ }
+ case TensorProto::kTensorContentFieldNumber: {
+ // If we haven't seen the dtype and tensor_shape data first, we can't
+ // deal with this in the fast path.
+ if (seen_tensor_content) return false;
+ if (wt != WIRETYPE_LENGTH_DELIMITED ||
+ !tensor_meta->has_tensor_shape()) {
+ return false;
+ }
+ int num_bytes;
+ if (!ReadVarintSizeAsInt(input, &num_bytes)) return false;
+ seen_tensor_content = true;
+ TensorShape shape(tensor_meta->tensor_shape());
+ Tensor t(allocator_, tensor_meta->dtype(), shape);
+ StringPiece buf = t.tensor_data();
+ if (num_bytes != buf.size()) return false;
+ // TODO(jeff,sanjay): Figure out a way to avoid this copy if
+ // the underlying ZeroCopyInputStream data is properly aligned
+ // and compatible with what allocator_ wants.
+ if (!input->ReadRaw(const_cast<char*>(buf.data()), num_bytes))
+ return false;
+ tensor_ = std::move(t);
+ break;
+ }
+ default: {
+ // Some other tag our fast path code is not prepared to handle.
+ // return false.
+ return false;
+ }
+ }
+ }
+}
+
+bool TensorResponse::ParseFast(Source* source) {
+ protobuf::io::CodedInputStream input(source->contents());
+ while (true) {
+ auto p = input.ReadTagWithCutoff(127);
+ int tag = GetTagFieldNumber(p.first);
+ WireType wt = GetTagWireType(p.first);
+ if (!p.second) {
+ return (tag == 0);
+ }
+ switch (tag) {
+ case RecvTensorResponse::kTensorFieldNumber: {
+ if (wt != WIRETYPE_LENGTH_DELIMITED) return false;
+
+ int length;
+ if (!ReadVarintSizeAsInt(&input, &length)) return false;
+ std::pair<protobuf::io::CodedInputStream::Limit, int> p =
+ input.IncrementRecursionDepthAndPushLimit(length);
+ if (p.second < 0 ||
+ !ParseTensorSubmessage(&input, meta_.mutable_tensor())) {
+ return false;
+ }
+ if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
+ return false;
+ }
+ break;
+ }
+ case RecvTensorResponse::kIsDeadFieldNumber: {
+ uint32 v;
+ if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
+ meta_.set_is_dead((v != 0) ? true : false);
+ break;
+ }
+ case RecvTensorResponse::kSendStartMicrosFieldNumber: {
+ uint64 v;
+ if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) return false;
+ meta_.set_send_start_micros(static_cast<int64>(v));
+ break;
+ }
+ case RecvTensorResponse::kTransportOptionsFieldNumber: {
+ if ((wt != WIRETYPE_LENGTH_DELIMITED) ||
+ !ReadNestedMessage(&input, meta_.mutable_transport_options()))
+ return false;
+ break;
+ }
+ default: {
+ // Unknown tag, so don't handle we can't handle on the fast path
+ return false;
+ }
+ }
+ }
+
+ return false;
+}
+
+bool TensorResponse::ParseSlow(Source* source) {
+ if (!meta_.ParseFromZeroCopyStream(source->contents())) {
+ return false;
+ }
+
+ Tensor parsed(meta_.tensor().dtype());
+ if (!parsed.FromProto(allocator_, meta_.tensor())) {
+ return false;
+ }
+ tensor_ = std::move(parsed);
+
+ // Reduce memory usage for big tensors.
+ {
+ TensorProto empty;
+ meta_.mutable_tensor()->Swap(&empty);
+ }
+ meta_.clear_tensor();
+
+ return true;
+}
+
+void TensorResponse::Clear() {
+ meta_.Clear();
+ tensor_ = Tensor();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/tensor_coding.h b/tensorflow/core/distributed_runtime/tensor_coding.h
new file mode 100644
index 0000000000..e193b0776d
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/tensor_coding.h
@@ -0,0 +1,85 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+
+namespace tensorflow {
+
+class Allocator;
+class TensorProto;
+
+// TensorResponse can be used as the destination of an RPC that returns
+// a RecvTensorResponse. It efficiently decodes the incoming data
+// into Tensor contents as well as associated metadata.
+class TensorResponse {
+ public:
+ explicit TensorResponse(Allocator* allocator);
+
+ // Source provides a way for a particular RPC implementation to provide
+ // received data to ParseFrom.
+ class Source {
+ public:
+ // Return the stream that contains the data to be parsed.
+ // Note that this method might be invoked more than once if
+ // ParseFrom needs to fall back to a more expensive parsing method.
+ // Every call must return a stream pointing at the beginning of
+ // the serialized RecvTensorResponse.
+ //
+ // Note that a subsequent call to contents() invalidates previous
+ // results of contents().
+ //
+ // Ownership of the returned stream is retained by the Source and
+ // should not be deleted by the caller.
+ virtual ::tensorflow::protobuf::io::ZeroCopyInputStream* contents() = 0;
+ };
+
+ // Parse the RecvTensorResponse encoded in the data yielded by
+ // source->contents() into *this.
+ Status ParseFrom(Source* source);
+
+ // Return a reference to the parsed tensor. The tensor will remain
+ // live only until *this is destroyed or modified.
+ const Tensor& tensor() const { return tensor_; }
+
+ // Return a reference to the parsed tensor metadata (no contents).
+ // The result will remain live only until *this is destroyed or
+ // modified.
+ const RecvTensorResponse& metadata() const { return meta_; }
+
+ // Clear contents of *this.
+ void Clear();
+
+ private:
+ bool ParseTensorSubmessage(protobuf::io::CodedInputStream* input,
+ TensorProto* tensor_meta);
+ bool ParseFast(Source* source);
+ bool ParseSlow(Source* source);
+
+ Allocator* allocator_ = nullptr;
+ bool already_used_ = false;
+ Tensor tensor_;
+ RecvTensorResponse meta_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_
diff --git a/tensorflow/core/distributed_runtime/tensor_coding_test.cc b/tensorflow/core/distributed_runtime/tensor_coding_test.cc
new file mode 100644
index 0000000000..0b1d3b6189
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/tensor_coding_test.cc
@@ -0,0 +1,186 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/tensor_coding.h"
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+
+namespace tensorflow {
+
+class StringSource : public TensorResponse::Source {
+ public:
+ explicit StringSource(const string* s, int block_size)
+ : s_(s), stream_(nullptr), block_size_(block_size) {}
+ virtual ~StringSource() { DeleteStream(); }
+
+ protobuf::io::ZeroCopyInputStream* contents() {
+ DeleteStream();
+ stream_ = new (&space_)
+ protobuf::io::ArrayInputStream(s_->data(), s_->size(), block_size_);
+ return stream_;
+ }
+
+ void DeleteStream() {
+ if (stream_) {
+ stream_->~ArrayInputStream();
+ }
+ }
+
+ private:
+ const string* s_;
+ protobuf::io::ArrayInputStream* stream_;
+ char space_[sizeof(protobuf::io::ArrayInputStream)];
+ int block_size_;
+};
+
+class TensorResponseTest : public ::testing::Test {
+ public:
+ void Validate(const Tensor& src, bool is_dead, bool use_tensor_content) {
+ RecvTensorResponse proto;
+ proto.set_is_dead(is_dead);
+ proto.set_send_start_micros(123456);
+ if (use_tensor_content) {
+ src.AsProtoTensorContent(proto.mutable_tensor());
+ } else {
+ src.AsProtoField(proto.mutable_tensor());
+ }
+ string encoded;
+ proto.AppendToString(&encoded);
+
+ StringSource source(&encoded, 1024);
+
+ TensorResponse response(cpu_allocator());
+ for (int i = 0; i < 2; i++) { // Twice so we exercise reuse of "response"
+ Status s = response.ParseFrom(&source);
+ EXPECT_TRUE(s.ok());
+
+ const RecvTensorResponse& meta = response.metadata();
+ EXPECT_EQ(meta.is_dead(), is_dead);
+ EXPECT_EQ(meta.send_start_micros(), 123456);
+
+ const Tensor& result = response.tensor();
+ EXPECT_EQ(result.dtype(), src.dtype());
+ EXPECT_EQ(result.shape().DebugString(), src.shape().DebugString());
+ EXPECT_EQ(result.DebugString(), src.DebugString());
+ }
+ }
+
+ template <typename T>
+ void DoTest(DataType dt) {
+ gtl::InlinedVector<T, 4> v;
+ LOG(ERROR) << "DT: " << static_cast<int>(dt);
+ for (int elems = 0; elems <= 10000; elems++) {
+ if (elems < 100 || (elems % 1000 == 0)) {
+ Tensor a(dt, TensorShape({1, static_cast<int64>(v.size())}));
+ test::FillValues<T>(&a, v);
+ Validate(a, (elems == 0), true);
+ }
+ v.push_back(static_cast<T>(elems));
+ }
+ }
+ void DoTestForStrings(DataType dt) {
+ gtl::InlinedVector<string, 4> v;
+ LOG(ERROR) << "DT: string";
+ for (int elems = 0; elems <= 10000; elems++) {
+ if (elems < 100 || (elems % 1000 == 0)) {
+ Tensor a(dt, TensorShape({1, static_cast<int64>(v.size())}));
+ test::FillValues<string>(&a, v);
+ Validate(a, (elems == 0), true);
+ }
+ v.push_back(strings::StrCat("This is string ", elems));
+ }
+ }
+};
+
+TEST_F(TensorResponseTest, Simple) {
+ DoTest<float>(DT_FLOAT);
+ DoTest<double>(DT_DOUBLE);
+ DoTest<int32>(DT_INT32);
+ DoTest<uint16>(DT_UINT16);
+ DoTest<uint8>(DT_UINT8);
+ DoTest<int16>(DT_INT16);
+ DoTest<int8>(DT_INT8);
+ DoTest<complex64>(DT_COMPLEX64);
+ DoTest<complex128>(DT_COMPLEX128);
+ DoTest<int64>(DT_INT64);
+ DoTest<bool>(DT_BOOL);
+ DoTest<qint8>(DT_QINT8);
+ DoTest<quint8>(DT_QUINT8);
+ DoTest<qint16>(DT_QINT16);
+ DoTest<quint16>(DT_QUINT16);
+ DoTest<qint32>(DT_QINT32);
+ DoTest<bfloat16>(DT_BFLOAT16);
+ DoTest<Eigen::half>(DT_HALF);
+}
+
+TEST_F(TensorResponseTest, StringTensor) { DoTestForStrings(DT_STRING); }
+
+string MakeFloatTensorTestCase(int num_elems) {
+ std::vector<int8> v(num_elems);
+ for (int i = 0; i < num_elems; i++) {
+ v[i] = i % 10;
+ }
+ Tensor src(DT_INT8, TensorShape({1, static_cast<int64>(v.size())}));
+ test::FillValues<int8>(&src, v);
+
+ RecvTensorResponse proto;
+ proto.set_is_dead(false);
+ proto.set_send_start_micros(123456);
+ src.AsProtoTensorContent(proto.mutable_tensor());
+ string encoded;
+ proto.AppendToString(&encoded);
+ return encoded;
+}
+
+static void BM_TensorResponse(int iters, int arg) {
+ testing::StopTiming();
+ string encoded = MakeFloatTensorTestCase(arg);
+ testing::StartTiming();
+ while (--iters > 0) {
+ TensorResponse response(cpu_allocator());
+ StringSource source(&encoded, -1);
+ Status s = response.ParseFrom(&source);
+ if (iters == 1) {
+ testing::SetLabel(
+ strings::StrCat("Bytes: ", response.tensor().TotalBytes()));
+ }
+ }
+}
+BENCHMARK(BM_TensorResponse)->Arg(0)->Arg(1000)->Arg(100000);
+
+static void BM_TensorViaTensorProto(int iters, int arg) {
+ testing::StopTiming();
+ string encoded = MakeFloatTensorTestCase(arg);
+ testing::StartTiming();
+ while (--iters > 0) {
+ RecvTensorResponse r;
+ r.ParseFromString(encoded);
+ Tensor t;
+ CHECK(t.FromProto(r.tensor()));
+ if (iters == 1) {
+ testing::SetLabel(strings::StrCat("Bytes: ", t.TotalBytes()));
+ }
+ }
+}
+BENCHMARK(BM_TensorViaTensorProto)->Arg(0)->Arg(1000)->Arg(100000);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/default/protobuf.h b/tensorflow/core/platform/default/protobuf.h
index 544f62d1f8..acc804a4ee 100644
--- a/tensorflow/core/platform/default/protobuf.h
+++ b/tensorflow/core/platform/default/protobuf.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "google/protobuf/descriptor.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
+#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/map.h"
#include "google/protobuf/repeated_field.h"
#include "google/protobuf/text_format.h"