/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // Utilities for saving/restoring tensor slice checkpoints. #ifndef TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ #define TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_ #include // for string #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" // for Status #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { namespace checkpoint { // The key for the metadata in the tensor slice checkpoint files. It is "" so // that the metadata is always at the beginning of a checkpoint file. extern const char kSavedTensorSlicesKey[]; // Encode a tensor name + a tensor slice into an ordered code and outputs it as // a string. // The format is // <0> // // // // // ... string EncodeTensorNameSlice(const string& name, const tensorflow::TensorSlice& slice); // Parse out the name and the slice from string encoded as an ordered code. Status DecodeTensorNameSlice(const string& code, string* name, tensorflow::TensorSlice* slice); // Extracts the full shape, slice spec, and shape of the slice from // "shape_and_slice". On non-OK return, caller must clear the out-arguments // before reusing. Status ParseShapeAndSlice(const string& shape_and_slice, TensorShape* shape, TensorSlice* slice, TensorShape* shape_slice); template struct SaveTypeTraits; template const typename SaveTypeTraits::SavedType* TensorProtoData( const TensorProto& t); template typename SaveTypeTraits::RepeatedField* MutableTensorProtoData( TensorProto* t); template void Fill(T* data, size_t n, TensorProto* t); #define TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, STYPE) \ template <> \ struct SaveTypeTraits { \ static constexpr bool supported = true; \ typedef STYPE SavedType; \ typedef protobuf::RepeatedField RepeatedField; \ }; \ template <> \ inline const STYPE* TensorProtoData(const TensorProto& t) { \ static_assert(SaveTypeTraits::supported, \ "Specified type " #TYPE " not supported for Restore"); \ return reinterpret_cast(t.FIELD##_val().data()); \ } \ template <> \ inline protobuf::RepeatedField* MutableTensorProtoData( \ TensorProto * t) { \ static_assert(SaveTypeTraits::supported, \ "Specified type " #TYPE " not supported for Save"); \ return reinterpret_cast*>( \ t->mutable_##FIELD##_val()); \ } #define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE) \ TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, FTYPE) \ template <> \ inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \ typename protobuf::RepeatedField copy(data, data + n); \ t->mutable_##FIELD##_val()->Swap(©); \ } // Complex needs special treatment since proto doesn't have native complex #define TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(TYPE, FIELD, FTYPE) \ TENSOR_PROTO_EXTRACT_TYPE_HELPER(TYPE, FIELD, FTYPE, TYPE) \ template <> \ inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \ const FTYPE* sub = reinterpret_cast(data); \ typename protobuf::RepeatedField copy(sub, sub + 2 * n); \ t->mutable_##FIELD##_val()->Swap(©); \ } TENSOR_PROTO_EXTRACT_TYPE(bool, bool, bool); TENSOR_PROTO_EXTRACT_TYPE(float, float, float); TENSOR_PROTO_EXTRACT_TYPE(double, double, double); TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex64, scomplex, float); TENSOR_PROTO_EXTRACT_TYPE_COMPLEX(complex128, dcomplex, double); TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32); TENSOR_PROTO_EXTRACT_TYPE(int64, int64, protobuf_int64); TENSOR_PROTO_EXTRACT_TYPE(uint16, int, int32); TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32); TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32); TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32); TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32); TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32); TENSOR_PROTO_EXTRACT_TYPE(quint16, int, int32); #undef TENSOR_PROTO_EXTRACT_TYPE_COMPLEX #undef TENSOR_PROTO_EXTRACT_TYPE_HELPER #undef TENSOR_PROTO_EXTRACT_TYPE // Custom implementation for qint32, based on the one for int32. template <> struct SaveTypeTraits : SaveTypeTraits {}; template <> inline const int32* TensorProtoData(const TensorProto& t) { static_assert(SaveTypeTraits::supported, "Specified type qint32 not supported for Restore"); return reinterpret_cast(t.int_val().data()); } inline void Fill(const qint32* data, size_t n, TensorProto* t) { const int32* p = reinterpret_cast(data); typename protobuf::RepeatedField copy(p, p + n); t->mutable_int_val()->Swap(©); } // Custom implementation for Eigen::half. template <> struct SaveTypeTraits { static constexpr bool supported = true; typedef int SavedType; typedef protobuf::RepeatedField RepeatedField; }; template <> inline const int* TensorProtoData(const TensorProto& t) { return t.half_val().data(); } template <> inline protobuf::RepeatedField* MutableTensorProtoData( TensorProto* t) { return t->mutable_half_val(); } template <> inline void Fill(const Eigen::half* data, size_t n, TensorProto* t) { typename protobuf::RepeatedField* val = t->mutable_half_val(); val->Resize(n, 0); for (size_t i = 0; i < n; ++i) { val->Set(i, data[i].x); } } // Custom implementation for string. template <> struct SaveTypeTraits { static constexpr bool supported = true; typedef const string* SavedType; typedef protobuf::RepeatedPtrField RepeatedField; }; template <> inline const string* const* TensorProtoData(const TensorProto& t) { static_assert(SaveTypeTraits::supported, "Specified type string not supported for Restore"); return t.string_val().data(); } template <> inline protobuf::RepeatedPtrField* MutableTensorProtoData( TensorProto* t) { static_assert(SaveTypeTraits::supported, "Specified type string not supported for Save"); return t->mutable_string_val(); } template <> inline void Fill(const string* data, size_t n, TensorProto* t) { typename protobuf::RepeatedPtrField copy(data, data + n); t->mutable_string_val()->Swap(©); } } // namespace checkpoint } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_