aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/tensor_slice.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/tensor_slice.cc')
-rw-r--r--tensorflow/core/framework/tensor_slice.cc226
1 files changed, 226 insertions, 0 deletions
diff --git a/tensorflow/core/framework/tensor_slice.cc b/tensorflow/core/framework/tensor_slice.cc
new file mode 100644
index 0000000000..473d9463ee
--- /dev/null
+++ b/tensorflow/core/framework/tensor_slice.cc
@@ -0,0 +1,226 @@
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+TensorSlice::TensorSlice(int dim) { SetFullSlice(dim); }
+
+TensorSlice::TensorSlice(const TensorSliceProto& proto) {
+ starts_.reserve(proto.extent_size());
+ lengths_.reserve(proto.extent_size());
+ for (const auto& e : proto.extent()) {
+ starts_.push_back(e.start());
+ lengths_.push_back(GetExtentLength(e));
+ }
+}
+
+TensorSlice::TensorSlice(std::initializer_list<std::pair<int, int>> extents) {
+ starts_.reserve(extents.size());
+ lengths_.reserve(extents.size());
+ for (const auto& e : extents) {
+ starts_.push_back(e.first);
+ lengths_.push_back(e.second);
+ }
+}
+
+Status TensorSlice::Parse(const string& str, TensorSlice* slice) {
+ std::vector<string> items = str_util::Split(str, ':', str_util::SkipEmpty());
+ slice->starts_.reserve(items.size());
+ slice->lengths_.reserve(items.size());
+ for (const string& x : items) {
+ int s, l;
+ if (x == "-") {
+ // "everything"
+ s = 0;
+ l = kFullExtent;
+ } else {
+ char junk;
+ if (sscanf(x.c_str(), "%d,%d%c", &s, &l, &junk) != 2) {
+ return errors::InvalidArgument(
+ "Expected a pair of numbers or '-' "
+ "but got '",
+ x, "': string = ", str);
+ }
+ if (s < 0 || l <= 0) {
+ return errors::InvalidArgument(
+ "Expected non-negative start and "
+ "positive length but got start = ",
+ s, ", length = ", l, ": string = ", str);
+ }
+ }
+ slice->starts_.push_back(s);
+ slice->lengths_.push_back(l);
+ }
+
+ return Status::OK();
+}
+
+void TensorSlice::Clear() {
+ starts_.clear();
+ lengths_.clear();
+}
+
+void TensorSlice::SetFullSlice(int dim) {
+ Clear();
+ starts_.reserve(dim);
+ lengths_.reserve(dim);
+ for (int d = 0; d < dim; ++d) {
+ starts_.push_back(0);
+ lengths_.push_back(kFullExtent);
+ }
+}
+
+void TensorSlice::Extend(int dim) {
+ int old_dim = dims();
+ DCHECK_LE(old_dim, dim);
+ starts_.resize(dim);
+ lengths_.resize(dim);
+ for (int d = old_dim; d < dim; ++d) {
+ starts_[d] = 0;
+ lengths_[d] = kFullExtent;
+ }
+}
+
+void TensorSlice::AsProto(TensorSliceProto* proto) const {
+ for (int d = 0; d < dims(); ++d) {
+ TensorSliceProto::Extent* e = proto->add_extent();
+ // We only need to record the explicit slice for non-full slices
+ if (!IsFullAt(d)) {
+ e->set_start(starts_[d]);
+ e->set_length(lengths_[d]);
+ }
+ }
+}
+
+string TensorSlice::DebugString() const {
+ string buffer;
+ bool first = true;
+ for (int d = 0; d < dims(); ++d) {
+ if (!first) {
+ buffer.append(":");
+ }
+ string s;
+ if (IsFullAt(d)) {
+ buffer.append("-");
+ } else {
+ strings::StrAppend(&buffer, starts_[d], ",", lengths_[d]);
+ }
+ first = false;
+ }
+ return buffer;
+}
+
+bool TensorSlice::Intersect(const TensorSlice& other,
+ TensorSlice* result) const {
+ // First, if two slices have different ranks, they obviously don't overlap
+ // -- in fact they are not compatible.
+ if (dims() != other.dims()) {
+ return false;
+ }
+
+ // Setting the result to the right dimension
+ if (result) {
+ result->SetFullSlice(dims());
+ }
+ // The two slices overlap if they overlap in all dimensions.
+ for (int d = 0; d < dims(); ++d) {
+ if (IsFullAt(d)) {
+ if (result) {
+ result->set_start(d, other.start(d));
+ result->set_length(d, other.length(d));
+ }
+ } else if (other.IsFullAt(d)) {
+ if (result) {
+ result->set_start(d, start(d));
+ result->set_length(d, length(d));
+ }
+ } else {
+ // If we have an intersection here, it should have a start that is the
+ // max of the two starts and an end that is the min of the two ends.
+ int s = std::max(start(d), other.start(d));
+ int l = std::min(end(d), other.end(d)) - s;
+ if (l > 0) {
+ // We have a real intersection
+ if (result) {
+ result->set_start(d, s);
+ result->set_length(d, l);
+ }
+ } else {
+ // We don't have an intersection for this dimension -- thus we don't
+ // have any intersection at all.
+ if (result) {
+ result->Clear();
+ }
+ return false;
+ }
+ }
+ }
+ // If we are here, we know there is overlap in every dimension.
+ return true;
+}
+
+void TensorSlice::ComputeRelative(const TensorSlice& sub,
+ TensorSlice* relative) const {
+ DCHECK_EQ(dims(), sub.dims());
+ relative->SetFullSlice(dims());
+ for (int d = 0; d < dims(); ++d) {
+ if (IsFullAt(d)) {
+ relative->set_start(d, sub.start(d));
+ relative->set_length(d, sub.length(d));
+ } else {
+ // Otherwise the relative start is the difference between the start of
+ // sub and the start of base
+ relative->set_start(d, sub.start(d) - start(d));
+ relative->set_length(d, sub.length(d));
+ }
+ }
+}
+
+// static
+bool TensorSlice::HasExtentLength(const TensorSliceProto::Extent& extent) {
+ return extent.has_length_case() == TensorSliceProto::Extent::kLength;
+}
+
+// static
+int64 TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) {
+ if (!HasExtentLength(extent)) return -1;
+ return extent.length();
+}
+
+Status TensorSlice::SliceTensorShape(const TensorShape& shape,
+ TensorShape* result_shape) const {
+ result_shape->Clear();
+ // Mismatching ranks: we can't apply the slice at all.
+ if (shape.dims() != dims()) {
+ return errors::Internal("Mismatching ranks: shape = ", shape.DebugString(),
+ ", slice = ", DebugString());
+ }
+ for (int d = 0; d < dims(); ++d) {
+ if (IsFullAt(d)) {
+ result_shape->AddDim(shape.dim_size(d));
+ } else {
+ // Check if the extent applies to the dimension
+ if (end(d) <= shape.dim_size(d)) {
+ // Yes: the end is within the range of the dim -- we adjust the result
+ // shape so that its size along this dimension is the length of the
+ // slice.
+ result_shape->AddDim(length(d));
+ } else {
+ // The extent doesn't apply to the dimension
+ result_shape->Clear();
+ return errors::Internal("Extent in dimension ", d,
+ " out of bounds: shape = ", shape.DebugString(),
+ ", slice = ", DebugString());
+ }
+ }
+ }
+ // If we are here, we have successfully applied the shape.
+ return Status::OK();
+}
+
+const int TensorSlice::kFullExtent = -1;
+
+} // namespace tensorflow