// Utilities for saving/restoring tensor slice checkpoints. #ifndef TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ #define TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ #include // for string #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/public/status.h" // for Status namespace tensorflow { namespace checkpoint { // The key for the metadata in the tensor slice checkpoint files. It is "" so // that the metadata is always at the beginning of a checkpoint file. extern const char kSavedTensorSlicesKey[]; // Encode a tensor name + a tensor slice into an ordered code and outputs it as // a string. // The format is // <0> // // // // // ... string EncodeTensorNameSlice(const string& name, const tensorflow::TensorSlice& slice); // Parse out the name and the slice from string encoded as an ordered code. Status DecodeTensorNameSlice(const string& code, string* name, tensorflow::TensorSlice* slice); template struct SaveTypeTraits; template const typename SaveTypeTraits::SavedType* TensorProtoData( const TensorProto& t); template protobuf::RepeatedField::SavedType>* MutableTensorProtoData(TensorProto* t); template void Fill(T* data, size_t n, TensorProto* t); #define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE) \ template <> \ struct SaveTypeTraits { \ static constexpr bool supported = true; \ typedef FTYPE SavedType; \ }; \ template <> \ inline const FTYPE* TensorProtoData(const TensorProto& t) { \ static_assert(SaveTypeTraits::supported, \ "Specified type " #TYPE " not supported for Restore"); \ return reinterpret_cast(t.FIELD##_val().data()); \ } \ template <> \ inline protobuf::RepeatedField* MutableTensorProtoData( \ TensorProto * t) { \ static_assert(SaveTypeTraits::supported, \ "Specified type " #TYPE " not supported for Save"); \ return reinterpret_cast*>( \ t->mutable_##FIELD##_val()); \ } \ template <> \ inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \ typename protobuf::RepeatedField copy(data, data + n); \ t->mutable_##FIELD##_val()->Swap(©); \ } TENSOR_PROTO_EXTRACT_TYPE(float, float, float); TENSOR_PROTO_EXTRACT_TYPE(double, double, double); TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32); TENSOR_PROTO_EXTRACT_TYPE(int64, int64, int64); TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32); TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32); TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32); TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32); TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32); #undef TENSOR_PROTO_EXTRACT_TYPE template <> struct SaveTypeTraits : SaveTypeTraits {}; template <> inline const int32* TensorProtoData(const TensorProto& t) { static_assert(SaveTypeTraits::supported, "Specified type qint32 not supported for Restore"); return reinterpret_cast(t.int_val().data()); } inline void Fill(const qint32* data, size_t n, TensorProto* t) { const int32* p = reinterpret_cast(data); typename protobuf::RepeatedField copy(p, p + n); t->mutable_int_val()->Swap(©); } } // namespace checkpoint } // namespace tensorflow #endif // TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_