aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime/tensor_coding.h
blob: 4c34297990d399e4e42f5776cd23fb660c9090c5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

// Forward declarations for types that are referenced only by pointer below.
class Allocator;
class DeviceBase;
class TensorProto;

// TensorResponse is meant to serve as the destination of an RPC whose reply
// is a RecvTensorResponse. It decodes the incoming data into the contents of
// a Tensor and keeps the associated metadata alongside it.
class TensorResponse {
 public:
  // Interface through which a particular RPC implementation hands the
  // received data to ParseFrom().
  class Source {
   public:
    virtual ~Source();

    // Returns the stream holding the data to be parsed. The stream is owned
    // by the Source; callers must not delete it.
    //
    // ParseFrom() may call this more than once if it has to fall back to a
    // more expensive parsing method, so every call must hand back a stream
    // positioned at the beginning of the serialized RecvTensorResponse.
    // Each call invalidates whatever an earlier call returned.
    virtual ::tensorflow::protobuf::io::ZeroCopyInputStream* contents() = 0;
  };

  TensorResponse() = default;

  // Returns *this to its initial state.
  void Clear();

  // Clears only tensor_ and meta_; the allocation-related members are left
  // as they are.
  void ClearTensor();

  // Initializes the allocation-related members from device `d` and the
  // allocator attributes `aa`.
  void InitAlloc(DeviceBase* d, const AllocatorAttributes& aa);

  // Parses the RecvTensorResponse encoded in the data yielded by
  // source->contents() into *this.
  Status ParseFrom(Source* source);

  // Initializes the tensor from *response. Afterwards *response holds
  // unspecified contents and should not be relied upon.
  Status InitFrom(RecvTensorResponse* response);

  // Initializes the tensor metadata from `response` and allocates backing
  // storage for the actual contents without initializing it.
  void InitPartial(const RecvTensorResponse& response);

  // Accessors. The references returned below remain live only until *this
  // is destroyed or modified.

  // The parsed tensor.
  const Tensor& tensor() const {
    return tensor_;
  }

  // The parsed tensor metadata (no tensor contents).
  const RecvTensorResponse& metadata() const {
    return meta_;
  }

  // The device hosting the tensor.
  DeviceBase* device() const {
    return device_;
  }

 private:
  // Parsing helpers.
  bool ParseTensorSubmessage(protobuf::io::CodedInputStream* input,
                             TensorProto* tensor_meta);
  bool ParseFast(Source* source);
  bool ParseSlow(Source* source);

  // Members other than the parse results; ClearTensor() leaves these alone.
  bool on_host_ = false;
  DeviceBase* device_ = nullptr;
  AllocatorAttributes alloc_attrs_;
  Allocator* allocator_ = nullptr;
  bool already_used_ = false;

  // Parse results, exposed through tensor() and metadata().
  Tensor tensor_;
  RecvTensorResponse meta_;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_