diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-08-15 19:34:13 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-16 09:56:16 -0700 |
commit | 0f7c3b0a886300c332fa66df58b4d4a5d477a4d9 (patch) | |
tree | 1480644be31cbbca7d64e44dba12a018948c330a /tensorflow/core/distributed_runtime/tensor_coding.cc | |
parent | 2c5718c133d36440f3dfd005a5d199db342faab8 (diff) |
Update generated Python Op docs.
Change: 130356783
Diffstat (limited to 'tensorflow/core/distributed_runtime/tensor_coding.cc')
-rw-r--r-- | tensorflow/core/distributed_runtime/tensor_coding.cc | 72 |
1 files changed, 65 insertions, 7 deletions
diff --git a/tensorflow/core/distributed_runtime/tensor_coding.cc b/tensorflow/core/distributed_runtime/tensor_coding.cc index 72399c9b11..b44babafa3 100644 --- a/tensorflow/core/distributed_runtime/tensor_coding.cc +++ b/tensorflow/core/distributed_runtime/tensor_coding.cc @@ -14,14 +14,76 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/distributed_runtime/tensor_coding.h" +#include "tensorflow/core/common_runtime/device.h" namespace tensorflow { -TensorResponse::TensorResponse(Allocator* allocator) : allocator_(allocator) {} +TensorResponse::Source::~Source() {} + +void TensorResponse::Clear() { + on_host_ = false; + device_ = nullptr; + alloc_attrs_ = AllocatorAttributes(); + allocator_ = nullptr; + already_used_ = false; + ClearTensor(); +} + +void TensorResponse::ClearTensor() { + meta_.Clear(); + tensor_ = Tensor(); +} + +void TensorResponse::InitAlloc(DeviceBase* d, const AllocatorAttributes& aa) { + Clear(); + device_ = d; + alloc_attrs_ = aa; + const DeviceAttributes& da = d->attributes(); + if (alloc_attrs_.on_host() || da.device_type() == "CPU") { + on_host_ = true; + } + allocator_ = device_->GetAllocator(alloc_attrs_); +} + +Status TensorResponse::InitFrom(RecvTensorResponse* response) { + Status s; + meta_.Swap(response); + if (on_host_) { + if (!tensor_.FromProto(allocator_, meta_.tensor())) { + s = errors::InvalidArgument("Cannot parse tensor from response"); + } + } else { + s = device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_); + } + { + TensorProto empty; + meta_.mutable_tensor()->Swap(&empty); + } + meta_.clear_tensor(); + return s; +} + +void TensorResponse::InitPartial(RecvTensorResponse* response) { + // Everything except content is present in *response. Content will + // arrive later; allocate a Tensor with appropriate storage for that + // content. + meta_.Swap(response); + TensorShape shape(meta_.tensor().tensor_shape()); + Tensor t(allocator_, meta_.tensor().dtype(), shape); + tensor_ = std::move(t); +} Status TensorResponse::ParseFrom(Source* source) { + if (!on_host_) { + // Pre-parse into local storage, then delegate to device. + RecvTensorResponse proto; + if (!proto.ParseFromZeroCopyStream(source->contents())) { + return errors::InvalidArgument("Cannot parse tensor from response"); + } + return device_->MakeTensorFromProto(proto.tensor(), alloc_attrs_, &tensor_); + } if (already_used_) { - Clear(); + ClearTensor(); } already_used_ = true; if (ParseFast(source)) return Status::OK(); @@ -140,6 +202,7 @@ bool TensorResponse::ParseTensorSubmessage( bool TensorResponse::ParseFast(Source* source) { protobuf::io::CodedInputStream input(source->contents()); + input.SetTotalBytesLimit(INT_MAX, INT_MAX); // Unlimited while (true) { auto p = input.ReadTagWithCutoff(127); int tag = GetTagFieldNumber(p.first); @@ -213,9 +276,4 @@ bool TensorResponse::ParseSlow(Source* source) { return true; } -void TensorResponse::Clear() { - meta_.Clear(); - tensor_ = Tensor(); -} - } // namespace tensorflow |