// File: tensorflow/core/util/tensor_slice_set.h
// A class to manage slices of a tensor. You can "register" a set of slices for
// a tensor and then "query" whether we have data for a given slice.

// TODO(yangke): consider moving it to a more private place so that we don't
// need to expose the API.

#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
#define TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_

#include <string>                 // for string
#include <unordered_map>
#include <utility>  // for std::pair
#include <vector>   // for std::vector

#include "tensorflow/core/platform/port.h"  // for int64
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/public/tensor_shape.h"
#include "tensorflow/core/lib/core/stringpiece.h"  // for StringPiece
#include "tensorflow/core/public/status.h"         // for Status

namespace tensorflow {

namespace checkpoint {

// Keeps track of the slices registered for one tensor, identified by a fixed
// shape and data type, and answers queries about whether the data for a
// requested slice is available.
//
// The set never owns slice data: every non-null pointer handed to Register()
// must stay valid for as long as this object is alive.
class TensorSliceSet {
 public:
  TensorSliceSet(const TensorShape& shape, DataType type);
  virtual ~TensorSliceSet();

  // Accessors for the shape and data type this set was created with.
  const TensorShape& shape() const { return shape_; }
  // Returned by value; a top-level const on a by-value return has no effect,
  // so none is spelled here.
  DataType type() const { return type_; }

  // Register a new slice for the tensor. The "tag" is an arbitrary string
  // associated with the slice (in one application it denotes the name of the
  // file that contains the slice); the "data" points to the data of the tensor
  // slice (it can be a nullptr).
  // We don't take the ownership of "data" and the caller needs to make sure
  // the data is always available during the life time of the tensor slice set
  // if it is not nullptr.
  Status Register(const TensorSlice& slice, const string& tag,
                  const float* data);

  // Query about a new slice: checks if we have data for "slice" and if we have
  // the data and "data" is not nullptr, fill "data" with the slice data. The
  // caller needs to make sure "data" points to a large enough buffer.
  // TODO(yangke): avoid unnecessary copying by using a core::RefCounted
  // pointer.
  bool Query(const TensorSlice& slice, float* data) const;

  // Alternative way of querying about a new slice: instead of copying the
  // data, it returns a list of meta data about the stored slices that will
  // supply data for the slice. Each entry of "results" pairs a stored slice
  // with a string (presumably the tag given to Register() -- confirm against
  // the definition).
  bool QueryMeta(
      const TensorSlice& slice,
      std::vector<std::pair<tensorflow::TensorSlice, string>>* results) const;

 private:
  const TensorShape shape_;
  const DataType type_;

  // Bookkeeping kept for each registered slice.
  struct SliceInfo {
    // The registered slice.
    TensorSlice slice;
    // Tag associated with the slice at registration time.
    const string tag;
    // Caller-owned slice data; not owned by this set.
    const float* data;
    // NOTE(review): presumably the number of floats reachable through "data";
    // confirm against the definition of Register().
    int64 num_floats;
  };

  // We maintain a mapping from the slice string to the slice information.
  std::unordered_map<string, SliceInfo> slices_;
};

}  // namespace checkpoint

}  // namespace tensorflow

#endif  // TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_