aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/saved_tensor_slice_util.h
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
committerGravatar Manjunath Kudlur <keveman@gmail.com>2015-11-06 16:27:58 -0800
commitf41959ccb2d9d4c722fe8fc3351401d53bcf4900 (patch)
treeef0ca22cb2a5ac4bdec9d080d8e0788a53ed496d /tensorflow/core/util/saved_tensor_slice_util.h
TensorFlow: Initial commit of TensorFlow library.
TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108
Diffstat (limited to 'tensorflow/core/util/saved_tensor_slice_util.h')
-rw-r--r--tensorflow/core/util/saved_tensor_slice_util.h110
1 files changed, 110 insertions, 0 deletions
diff --git a/tensorflow/core/util/saved_tensor_slice_util.h b/tensorflow/core/util/saved_tensor_slice_util.h
new file mode 100644
index 0000000000..6206cd8538
--- /dev/null
+++ b/tensorflow/core/util/saved_tensor_slice_util.h
@@ -0,0 +1,110 @@
+// Utilities for saving/restoring tensor slice checkpoints.
+
+#ifndef TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
+#define TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
+
+#include <string> // for string
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/status.h" // for Status
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+// The key for the metadata in the tensor slice checkpoint files. It is "" so
+// that the metadata is always at the beginning of a checkpoint file.
+extern const char kSavedTensorSlicesKey[];
+
+// Encode a tensor name + a tensor slice into an ordered code and outputs it as
+// a string.
+// The format is
+// <0>
+// <tensor_name>
+// <rank>
+// <dim-0-start><dim-0-length>
+// <dim-1-start><dim-1-length>
+// ...
+
+string EncodeTensorNameSlice(const string& name,
+ const tensorflow::TensorSlice& slice);
+
+// Parse out the name and the slice from string encoded as an ordered code.
+Status DecodeTensorNameSlice(const string& code, string* name,
+ tensorflow::TensorSlice* slice);
+
+template <typename T>
+struct SaveTypeTraits;
+
+template <typename T>
+const typename SaveTypeTraits<T>::SavedType* TensorProtoData(
+ const TensorProto& t);
+
+template <typename T>
+protobuf::RepeatedField<typename SaveTypeTraits<T>::SavedType>*
+MutableTensorProtoData(TensorProto* t);
+
+template <typename T>
+void Fill(T* data, size_t n, TensorProto* t);
+
+#define TENSOR_PROTO_EXTRACT_TYPE(TYPE, FIELD, FTYPE) \
+ template <> \
+ struct SaveTypeTraits<TYPE> { \
+ static constexpr bool supported = true; \
+ typedef FTYPE SavedType; \
+ }; \
+ template <> \
+ inline const FTYPE* TensorProtoData<TYPE>(const TensorProto& t) { \
+ static_assert(SaveTypeTraits<TYPE>::supported, \
+ "Specified type " #TYPE " not supported for Restore"); \
+ return reinterpret_cast<const FTYPE*>(t.FIELD##_val().data()); \
+ } \
+ template <> \
+ inline protobuf::RepeatedField<FTYPE>* MutableTensorProtoData<TYPE>( \
+ TensorProto * t) { \
+ static_assert(SaveTypeTraits<TYPE>::supported, \
+ "Specified type " #TYPE " not supported for Save"); \
+ return reinterpret_cast<protobuf::RepeatedField<FTYPE>*>( \
+ t->mutable_##FIELD##_val()); \
+ } \
+ template <> \
+ inline void Fill(const TYPE* data, size_t n, TensorProto* t) { \
+ typename protobuf::RepeatedField<FTYPE> copy(data, data + n); \
+ t->mutable_##FIELD##_val()->Swap(&copy); \
+ }
+
+TENSOR_PROTO_EXTRACT_TYPE(float, float, float);
+TENSOR_PROTO_EXTRACT_TYPE(double, double, double);
+TENSOR_PROTO_EXTRACT_TYPE(int32, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(int64, int64, int64);
+TENSOR_PROTO_EXTRACT_TYPE(uint8, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(int8, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(int16, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(qint8, int, int32);
+TENSOR_PROTO_EXTRACT_TYPE(quint8, int, int32);
+
+#undef TENSOR_PROTO_EXTRACT_TYPE
+
+template <>
+struct SaveTypeTraits<qint32> : SaveTypeTraits<int32> {};
+
+template <>
+inline const int32* TensorProtoData<qint32>(const TensorProto& t) {
+ static_assert(SaveTypeTraits<qint32>::supported,
+ "Specified type qint32 not supported for Restore");
+ return reinterpret_cast<const int32*>(t.int_val().data());
+}
+
+inline void Fill(const qint32* data, size_t n, TensorProto* t) {
+ const int32* p = reinterpret_cast<const int32*>(data);
+ typename protobuf::RepeatedField<int32> copy(p, p + n);
+ t->mutable_int_val()->Swap(&copy);
+}
+
+} // namespace checkpoint
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_