aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime/tensor_coding.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-15 19:34:13 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-16 09:56:16 -0700
commit0f7c3b0a886300c332fa66df58b4d4a5d477a4d9 (patch)
tree1480644be31cbbca7d64e44dba12a018948c330a /tensorflow/core/distributed_runtime/tensor_coding.cc
parent2c5718c133d36440f3dfd005a5d199db342faab8 (diff)
Update generated Python Op docs.
Change: 130356783
Diffstat (limited to 'tensorflow/core/distributed_runtime/tensor_coding.cc')
-rw-r--r--tensorflow/core/distributed_runtime/tensor_coding.cc72
1 file changed, 65 insertions, 7 deletions
diff --git a/tensorflow/core/distributed_runtime/tensor_coding.cc b/tensorflow/core/distributed_runtime/tensor_coding.cc
index 72399c9b11..b44babafa3 100644
--- a/tensorflow/core/distributed_runtime/tensor_coding.cc
+++ b/tensorflow/core/distributed_runtime/tensor_coding.cc
@@ -14,14 +14,76 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
+#include "tensorflow/core/common_runtime/device.h"
namespace tensorflow {
-TensorResponse::TensorResponse(Allocator* allocator) : allocator_(allocator) {}
+TensorResponse::Source::~Source() {}
+
+void TensorResponse::Clear() {
+ on_host_ = false;
+ device_ = nullptr;
+ alloc_attrs_ = AllocatorAttributes();
+ allocator_ = nullptr;
+ already_used_ = false;
+ ClearTensor();
+}
+
+void TensorResponse::ClearTensor() {
+ meta_.Clear();
+ tensor_ = Tensor();
+}
+
+void TensorResponse::InitAlloc(DeviceBase* d, const AllocatorAttributes& aa) {
+ Clear();
+ device_ = d;
+ alloc_attrs_ = aa;
+ const DeviceAttributes& da = d->attributes();
+ if (alloc_attrs_.on_host() || da.device_type() == "CPU") {
+ on_host_ = true;
+ }
+ allocator_ = device_->GetAllocator(alloc_attrs_);
+}
+
// Takes ownership of *response's contents (via Swap) and materializes the
// tensor it carries into tensor_.
//
// On return, meta_ holds the response metadata with its tensor field
// emptied, so the (potentially large) serialized tensor bytes are not
// retained alongside the parsed Tensor.
Status TensorResponse::InitFrom(RecvTensorResponse* response) {
  Status s;
  meta_.Swap(response);
  if (on_host_) {
    // Host memory: parse the proto directly into tensor_ using allocator_.
    if (!tensor_.FromProto(allocator_, meta_.tensor())) {
      s = errors::InvalidArgument("Cannot parse tensor from response");
    }
  } else {
    // Non-host memory: delegate proto-to-tensor conversion to the device
    // (presumably this performs the host-to-device copy — confirm).
    s = device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_);
  }
  {
    // Swapping with a scratch proto releases the serialized tensor's
    // storage; protobuf's Clear() alone may keep the memory allocated
    // for reuse.
    TensorProto empty;
    meta_.mutable_tensor()->Swap(&empty);
  }
  // Also reset the field's presence bit after the swap.
  meta_.clear_tensor();
  return s;
}
+
+void TensorResponse::InitPartial(RecvTensorResponse* response) {
+ // Everything except content is present in *response. Content will
+ // arrive later; allocate a Tensor with appropriate storage for that
+ // content.
+ meta_.Swap(response);
+ TensorShape shape(meta_.tensor().tensor_shape());
+ Tensor t(allocator_, meta_.tensor().dtype(), shape);
+ tensor_ = std::move(t);
+}
Status TensorResponse::ParseFrom(Source* source) {
+ if (!on_host_) {
+ // Pre-parse into local storage, then delegate to device.
+ RecvTensorResponse proto;
+ if (!proto.ParseFromZeroCopyStream(source->contents())) {
+ return errors::InvalidArgument("Cannot parse tensor from response");
+ }
+ return device_->MakeTensorFromProto(proto.tensor(), alloc_attrs_, &tensor_);
+ }
if (already_used_) {
- Clear();
+ ClearTensor();
}
already_used_ = true;
if (ParseFast(source)) return Status::OK();
@@ -140,6 +202,7 @@ bool TensorResponse::ParseTensorSubmessage(
bool TensorResponse::ParseFast(Source* source) {
protobuf::io::CodedInputStream input(source->contents());
+ input.SetTotalBytesLimit(INT_MAX, INT_MAX); // Unlimited
while (true) {
auto p = input.ReadTagWithCutoff(127);
int tag = GetTagFieldNumber(p.first);
@@ -213,9 +276,4 @@ bool TensorResponse::ParseSlow(Source* source) {
return true;
}
-void TensorResponse::Clear() {
- meta_.Clear();
- tensor_ = Tensor();
-}
-
} // namespace tensorflow