aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/tensor_slice_set.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/util/tensor_slice_set.h')
-rw-r--r--tensorflow/core/util/tensor_slice_set.h73
1 files changed, 73 insertions, 0 deletions
diff --git a/tensorflow/core/util/tensor_slice_set.h b/tensorflow/core/util/tensor_slice_set.h
new file mode 100644
index 0000000000..f3f7ac0e76
--- /dev/null
+++ b/tensorflow/core/util/tensor_slice_set.h
@@ -0,0 +1,73 @@
+// A class to manage slices of a tensor. You can "register" set of slices for a
+// tensor and then "query" if we have data for a given slice.
+
+// TODO(yangke): consider moving it to a more private place so that we don't
+// need to expose the API.
+
+#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
+#define TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
+
+#include <string> // for string
+#include <unordered_map>
+
+#include "tensorflow/core/platform/port.h" // for int64
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/lib/core/stringpiece.h" // for StringPiece
+#include "tensorflow/core/public/status.h" // for Status
+
+namespace tensorflow {
+
+namespace checkpoint {
+
+class TensorSliceSet {
+ public:
+ TensorSliceSet(const TensorShape& shape, DataType type);
+ virtual ~TensorSliceSet();
+
+ const TensorShape& shape() const { return shape_; }
+ const DataType type() const { return type_; }
+
+ // Register a new slice for the tensor. The "tag" is an arbitrary string
+ // associated with the slice (in one application it denotes the name of the
+ // file that contains the slice); the "data" points to the data of the tensor
+ // slice (it can be a nullptr).
+ // We don't take the ownership of "data" and the caller needs to make sure
+ // the data is always available during the life time of the tensor slice set
+ // if it is not nullptr.
+ Status Register(const TensorSlice& slice, const string& tag,
+ const float* data);
+
+ // Query about a new slice: checks if we have data for "slice" and if we have
+ // the data and "data" is not nullptr, fill "data" with the slice data. The
+ // caller needs to make sure "data" point to a large eough buffer.
+ // TODO(yangke): avoid unnecessary copying by using a core::RefCounted
+ // pointer.
+ bool Query(const TensorSlice& slice, float* data) const;
+
+ // Alternative way of querying about a new slice: instead of copying the
+ // data, it returns a list of meta data about the stored slices that will
+ // supply data for the slice.
+ bool QueryMeta(
+ const TensorSlice& slice,
+ std::vector<std::pair<tensorflow::TensorSlice, string>>* results) const;
+
+ private:
+ const TensorShape shape_;
+ const DataType type_;
+ struct SliceInfo {
+ TensorSlice slice;
+ const string tag;
+ const float* data;
+ int64 num_floats;
+ };
+ // We maintain a mapping from the slice string to the slice information.
+ std::unordered_map<string, SliceInfo> slices_;
+};
+
+} // namespace checkpoint
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_