aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/saved_tensor_slice_util.h
blob: 6206cd85389cd79c1642ba1bb4eb49d9b0c6f996 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
// Utilities for saving/restoring tensor slice checkpoints.

#ifndef TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
#define TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_

#include <string>  // for string
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/public/status.h"  // for Status

namespace tensorflow {

namespace checkpoint {

// Key under which the metadata entry is stored in a tensor slice checkpoint
// file. Its value is the empty string "", chosen so that the metadata entry is
// always at the beginning of a checkpoint file (it sorts before every other
// key). Only declared here; the definition lives in another translation unit.
extern const char kSavedTensorSlicesKey[];

// Encodes a tensor name plus a tensor slice into an ordered code and returns
// it as a string.
// The format is
//  <0>
//  <tensor_name>
//  <rank>
//  <dim-0-start><dim-0-length>
//  <dim-1-start><dim-1-length>
//  ...
// i.e. one <start><length> pair per dimension (presumably <rank> pairs in
// total; confirm against the implementation). DecodeTensorNameSlice parses
// this format back.

string EncodeTensorNameSlice(const string& name,
                             const tensorflow::TensorSlice& slice);

// Parses the tensor name and the tensor slice out of `code`, a string encoded
// as an ordered code (the format produced by EncodeTensorNameSlice).
// The results are written to *name and *slice. The returned Status reports
// success; presumably it is non-OK when `code` is not a valid encoding, but
// confirm the exact error cases in the implementation.
Status DecodeTensorNameSlice(const string& code, string* name,
                             tensorflow::TensorSlice* slice);

// Describes how elements of in-memory type T are stored in a TensorProto.
// The primary template is only declared, never defined. A type becomes
// supported through an explicit specialization that provides
// `static constexpr bool supported` and the `SavedType` typedef, i.e. the
// element type of the TensorProto repeated field holding T's values. Naming
// SaveTypeTraits<T>::SavedType for a type without a specialization fails to
// compile because the primary template is incomplete.
template <typename T>
struct SaveTypeTraits;

// Returns a read-only pointer to the first element of the TensorProto value
// field used for type T, typed as SaveTypeTraits<T>::SavedType.
template <typename T>
const typename SaveTypeTraits<T>::SavedType* TensorProtoData(
    const TensorProto& t);

// Returns the mutable TensorProto value field used to store elements of
// type T.
template <typename T>
protobuf::RepeatedField<typename SaveTypeTraits<T>::SavedType>*
MutableTensorProtoData(TensorProto* t);

// Replaces the contents of the value field of `*t` used for the element type
// with the `n` elements starting at `data`.
// NOTE(review): the specializations in this header are written against
// `const TYPE*` parameters, so they specialize this template for
// T = const TYPE. A call with a non-const pointer deduces T = TYPE, selects
// this undefined primary template and only fails at link time. Pass a
// pointer-to-const.
template <typename T>
void Fill(T* data, size_t n, TensorProto* t);

// Generates checkpoint save/restore support for one element type.
//   TYPE  - the in-memory element type (e.g. float, uint8, qint8).
//   FIELD - prefix of the TensorProto repeated value field that stores TYPE;
//           the field accessors are formed as FIELD##_val (e.g. int -> int_val).
//   FTYPE - element type of that field, exposed as SaveTypeTraits::SavedType.
// Each expansion defines:
//   * SaveTypeTraits<TYPE> with `supported = true` and `SavedType = FTYPE`;
//   * TensorProtoData<TYPE>: read-only pointer to the field's elements,
//     reinterpret_cast to `const FTYPE*`;
//   * MutableTensorProtoData<TYPE>: the mutable field, reinterpret_cast to
//     `RepeatedField<FTYPE>*`;
//   * Fill(const TYPE*, n, t): copies the n values element-wise into a fresh
//     RepeatedField<FTYPE> and swaps it into the field, replacing its contents.
// NOTE(review): the static_asserts are always satisfied, because the traits are
// defined with `supported = true` just above them in the same expansion.
// An expansion ends with a function definition, so invocations take no
// trailing ';' (that would only add an empty declaration at namespace scope).
#define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE)                    \
  template <>                                                            \
  struct SaveTypeTraits<TYPE> {                                          \
    static constexpr bool supported = true;                              \
    typedef FTYPE SavedType;                                             \
  };                                                                     \
  template <>                                                            \
  inline const FTYPE* TensorProtoData<TYPE>(const TensorProto& t) {      \
    static_assert(SaveTypeTraits<TYPE>::supported,                       \
                  "Specified type " #TYPE " not supported for Restore"); \
    return reinterpret_cast<const FTYPE*>(t.FIELD##_val().data());       \
  }                                                                      \
  template <>                                                            \
  inline protobuf::RepeatedField<FTYPE>* MutableTensorProtoData<TYPE>(   \
      TensorProto* t) {                                                  \
    static_assert(SaveTypeTraits<TYPE>::supported,                       \
                  "Specified type " #TYPE " not supported for Save");    \
    return reinterpret_cast<protobuf::RepeatedField<FTYPE>*>(            \
        t->mutable_##FIELD##_val());                                     \
  }                                                                      \
  template <>                                                            \
  inline void Fill(const TYPE* data, size_t n, TensorProto* t) {         \
    protobuf::RepeatedField<FTYPE> copy(data, data + n);                 \
    t->mutable_##FIELD##_val()->Swap(&copy);                             \
  }

// float, double and int64 use their own value fields. int32 and the narrower
// integer and 8-bit quantized types share int_val, with SavedType int32.
// qint32 is handled by hand further below.
TENSOR_PROTO_EXTRACT_TYPE(float, float, float)
TENSOR_PROTO_EXTRACT_TYPE(double, double, double)
TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32)
TENSOR_PROTO_EXTRACT_TYPE(int64, int64, int64)
TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32)
TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32)
TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32)
TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32)
TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32)

#undef TENSOR_PROTO_EXTRACT_TYPE

template <>
struct SaveTypeTraits<qint32> : SaveTypeTraits<int32> {};

template <>
inline const int32* TensorProtoData<qint32>(const TensorProto& t) {
  static_assert(SaveTypeTraits<qint32>::supported,
                "Specified type qint32 not supported for Restore");
  return reinterpret_cast<const int32*>(t.int_val().data());
}

inline void Fill(const qint32* data, size_t n, TensorProto* t) {
  const int32* p = reinterpret_cast<const int32*>(data);
  typename protobuf::RepeatedField<int32> copy(p, p + n);
  t->mutable_int_val()->Swap(&copy);
}

}  // namespace checkpoint

}  // namespace tensorflow

#endif  // TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_