aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/tensor_slice_set.cc
blob: 765686f189fb4ae48d1ef1ad51bfa7684965484b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#include "tensorflow/core/util/tensor_slice_set.h"

#include <algorithm>
#include <utility>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_slice_util.h"

namespace tensorflow {

namespace checkpoint {

// Creates a slice set describing a tensor of shape "shape" whose elements have
// type "type". No slices are registered yet; they are added via Register().
TensorSliceSet::TensorSliceSet(const TensorShape& shape, DataType type)
    : shape_(shape),
      type_(type) {}

// The destructor has no work of its own. In particular, the "data" pointers
// stored by Register() are not freed here; the set does not own them.
TensorSliceSet::~TensorSliceSet() = default;

// Registers "slice" (labelled "tag", backed by the buffer "data") in this set.
//
// Returns an error if "slice" is incompatible with shape_ (as reported by
// TensorSlice::SliceTensorShape), or an Internal error if it overlaps any
// slice that is already registered. On success the slice is stored keyed by
// its DebugString(). Only the "data" pointer is stored; the buffer it points
// to is not copied, so it must outlive any Query() that reads from it.
Status TensorSliceSet::Register(const TensorSlice& slice,
                                const string& tag, const float* data) {
  TensorShape result_shape;
  TF_RETURN_IF_ERROR(slice.SliceTensorShape(shape_, &result_shape));
  string str = slice.DebugString();
  // Registered slices must stay pairwise disjoint: Query() and QueryMeta()
  // decide coverage by summing intersection sizes, which is only valid if no
  // two registered slices intersect. Iterate by reference so that each
  // (key, SliceInfo) entry is not copied on every check.
  for (const auto& x : slices_) {
    if (slice.Overlaps(x.second.slice)) {
      return errors::Internal("Overlapping slices: existing slice = ", x.first,
                              ", new slice = ", str);
    }
  }
  // No overlap: we can now insert the slice. "str" and "info" are not used
  // afterwards, so hand them to the map instead of copying them.
  TensorSliceSet::SliceInfo info = {slice, tag, data,
                                    result_shape.num_elements()};
  slices_.insert(std::make_pair(std::move(str), std::move(info)));
  return Status::OK();
}

// TODO(yangke): merge Query() with QueryMeta()
bool TensorSliceSet::Query(const TensorSlice& slice, float* data) const {
  Status s;
  string str = slice.DebugString();
  // First we check if there is an exactly match (this is the dominant case).
  const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str);
  if (info) {
    if (data) {
      std::copy_n(info->data, info->num_floats, data);
    }
    return true;
  } else {
    // We didn't find any exact match but there is still a posibility that
    // mutliple existing slices can be patched together to output the slice.
    // We figure this out by computing the intersection of each of the existing
    // slices with the query slice, and check if the union of all these
    // intersections cover the entire slice. We rely on the fact that the
    // existing slices don't have any intersection among themselves.
    TensorShape target_shape;
    Status s;
    s = slice.SliceTensorShape(shape_, &target_shape);
    if (!s.ok()) {
      LOG(WARNING) << s;
      return false;
    }
    int64 total_size = target_shape.num_elements();

    int64 overlap_size = 0;
    TensorSlice intersection;
    TensorShape inter_shape;
    for (const auto x : slices_) {
      if (slice.Intersect(x.second.slice, &intersection)) {
        s = intersection.SliceTensorShape(shape_, &inter_shape);
        if (!s.ok()) {
          LOG(WARNING) << s;
          return false;
        }
        overlap_size += inter_shape.num_elements();
      }
    }
    if (total_size == overlap_size) {
      // We have it!
      // Now we need to copy the data to "data"
      if (data) {
        for (const auto x : slices_) {
          CopyDataFromTensorSliceToTensorSlice(shape_, x.second.slice, slice,
                                               x.second.data, data);
        }
      }
      return true;
    } else {
      // We don't have all the data for the asked tensor slice
      return false;
    }
  }
}

bool TensorSliceSet::QueryMeta(
    const TensorSlice& slice,
    std::vector<std::pair<TensorSlice, string>>* results) const {
  results->clear();
  Status s;
  string str = slice.DebugString();
  // First we check if there is an exactly match (this is the dominant case).
  const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str);
  if (info) {
    results->emplace_back(std::make_pair(info->slice, info->tag));
    return true;
  } else {
    // We didn't find any exact match but there is still a posibility that
    // multiple existing slices can be patched together to output the slice.
    // We figure this out by computing the intersection of each of the existing
    // slices with the query slice, and check if the union of all these
    // intersections cover the entire slice. We rely on the fact that the
    // existing slices don't have any intersection among themselves.
    TensorShape target_shape;
    Status s;
    s = slice.SliceTensorShape(shape_, &target_shape);
    if (!s.ok()) {
      LOG(WARNING) << s;
      return false;
    }
    int64 total_size = target_shape.num_elements();

    int64 overlap_size = 0;
    TensorSlice intersection;
    TensorShape inter_shape;
    for (const auto x : slices_) {
      if (slice.Intersect(x.second.slice, &intersection)) {
        s = intersection.SliceTensorShape(shape_, &inter_shape);
        if (!s.ok()) {
          LOG(WARNING) << s;
          return false;
        }
        overlap_size += inter_shape.num_elements();
        results->emplace_back(std::make_pair(x.second.slice, x.second.tag));
      }
    }
    if (total_size == overlap_size) {
      // We have it!
      return true;
    } else {
      // We don't have all the data for the asked tensor slice
      results->clear();
      return false;
    }
  }
}

}  // namespace checkpoint

}  // namespace tensorflow