aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/tensor_slice.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/tensor_slice.h')
-rw-r--r--tensorflow/core/framework/tensor_slice.h189
1 files changed, 189 insertions, 0 deletions
diff --git a/tensorflow/core/framework/tensor_slice.h b/tensorflow/core/framework/tensor_slice.h
new file mode 100644
index 0000000000..8e2f108c3f
--- /dev/null
+++ b/tensorflow/core/framework/tensor_slice.h
@@ -0,0 +1,189 @@
+#ifndef TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
+#define TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
+
+#include <string>
+#include "tensorflow/core/framework/tensor_slice.pb.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/tensor_shape.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/public/status.h"
+
+namespace tensorflow {
+
+// A tensor slice represents a slice of a given tensor. It is represented by a
+// list of (start, length) pairs, where the size of the list is the rank of the
+// tensor.
+
+class TensorSlice {
+ public:
+ // Construct a tensor slice: you have a number of ways:
+ // -- creating an empty slice
+ // -- from just a dimension (in this case it will create a full slice)
+ // -- from an array of pairs of integers.
+ // -- from a TensorSliceProto protocol buffer
+ // -- from a string format of "start,lenth:start,length..." where each
+ // "start,length" pair represents the slice on one dimension. We allow a
+ // special "-" that means "everything for this dimension". One such example
+ // is: 0,10:-:14,1:-:-
+ TensorSlice() {}
+ explicit TensorSlice(int dim);
+ explicit TensorSlice(const TensorSliceProto& proto);
+ explicit TensorSlice(std::initializer_list<std::pair<int, int>> extents);
+
+ static Status Parse(const string& str, TensorSlice* output);
+ static TensorSlice ParseOrDie(const string& str) {
+ TensorSlice ret;
+ Status s = Parse(str, &ret);
+ if (!s.ok()) {
+ LOG(FATAL) << "Could not parse TensorSlice";
+ }
+ return ret;
+ }
+
+ void Clear();
+
+ // Accessors
+ int dims() const { return starts_.size(); }
+
+ int start(int d) const {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ return starts_[d];
+ }
+
+ int length(int d) const {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ return lengths_[d];
+ }
+
+ int end(int d) const {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ return start(d) + length(d);
+ }
+
+ void set_start(int d, int x) {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ DCHECK_GE(x, 0);
+ starts_[d] = x;
+ }
+
+ void set_length(int d, int x) {
+ DCHECK_GE(d, 0);
+ DCHECK_LT(d, dims());
+ lengths_[d] = x;
+ }
+
+ // If we have a full slice along dimension "d".
+ bool IsFullAt(int d) const { return lengths_[d] < 0; }
+
+ // Set the slice to be a full slice of "dim" dimensions
+ void SetFullSlice(int dim);
+
+ // Extend a slice to "dim" dimensions: all the added dimensions are full.
+ // Requires: dim >= dims().
+ void Extend(int dim);
+
+ // Conversion of a TensorSlice to other formats
+ void AsProto(TensorSliceProto* proto) const;
+ string DebugString() const;
+
+ // Fill *indices and *sizes from *this (so that we can use the slice()
+ // function in eigen tensor). We need a tensor shape in case some of the
+ // slices are full slices.
+ // We allow NDIMS to be greater than dims(), in which case we will pad the
+ // higher dimensions with trivial dimensions.
+ template <int NDIMS>
+ void FillIndicesAndSizes(const TensorShape& shape,
+ Eigen::DSizes<ptrdiff_t, NDIMS>* indices,
+ Eigen::DSizes<ptrdiff_t, NDIMS>* sizes) const;
+
+ // Interaction with other TensorSlices.
+
+ // Compute the intersection with another slice and if "result" is not
+ // nullptr, store the results in *result; returns true is there is any real
+ // intersection.
+ bool Intersect(const TensorSlice& other, TensorSlice* result) const;
+ // A short hand.
+ bool Overlaps(const TensorSlice& other) const {
+ return Intersect(other, nullptr);
+ }
+
+ // Interaction with TensorShape.
+
+ // Slices a shape and stores the result into *result_shape.
+ // Requires that the shape and *this have the same rank.
+ // For example, given a tensor shape of {3, 4, 5}, and a slice of
+ // 1,2:-:0,2, the result shape is {2, 4, 2}.
+ Status SliceTensorShape(const TensorShape& shape,
+ TensorShape* result_shape) const;
+
+ // Given slice "sub" where "sub" is fully contained in *this,
+ // (meaning that the intersection of "sub" and *this equals "sub"), computes
+ // the "relative" slice of "sub" with respect to *this.
+ //
+ // In other words, if we use A>S to denote slicing a shape S with a slice A,
+ // then the function is computing a slice X such that:
+ // X > (this > S) = sub > S
+ // for any shape S.
+ //
+ // In general, along every dimension, the start of the relative slice is the
+ // start of the "sub" slice minus the start of *this; the length of the
+ // relative slice is the length of the "sub" slice.
+ //
+ // For example, say we have a shape of {3, 4, 5}, "this" is 0,2:-:1,2, and
+ // "sub" is 1,1:2:2,1,2, then the related slice is 1,1:2,2:0,2.
+ //
+ // The caller needs to make sure that "sub" is indeed a sub-slice of *this;
+ // otherwise the result is undefined.
+ void ComputeRelative(const TensorSlice& sub, TensorSlice* relative) const;
+
+ // Returns true if the length field was specified in an Extent.
+ static bool HasExtentLength(const TensorSliceProto::Extent& extent);
+
+ // Returns the value of the length field in an Extent, or -1 if it
+ // is not present.
+ static int64 GetExtentLength(const TensorSliceProto::Extent& extent);
+
+ private:
+ // a length value of kFullExtent (-1) means we have a full slice at this
+ // dimension. It's defined in tensor_slice.cc.
+ static const int kFullExtent;
+
+ // TODO(yangke): switch to Eigen once it supports variable size arrays.
+ // A value of
+ gtl::InlinedVector<int, 4> starts_;
+ gtl::InlinedVector<int, 4> lengths_;
+};
+
+template <int NDIMS>
+void TensorSlice::FillIndicesAndSizes(
+ const TensorShape& shape, Eigen::DSizes<ptrdiff_t, NDIMS>* indices,
+ Eigen::DSizes<ptrdiff_t, NDIMS>* sizes) const {
+ CHECK_EQ(shape.dims(), dims()) << "Incompatible dimensions between shape "
+ << "slices: shape = " << shape.DebugString()
+ << ", slice = " << DebugString();
+ CHECK_GE(NDIMS, dims()) << "Asking for a " << NDIMS << "-dim slice from "
+ << "a slice of dimension " << dims();
+ for (int d = 0; d < dims(); ++d) {
+ if (IsFullAt(d)) {
+ (*indices)[d] = 0;
+ (*sizes)[d] = shape.dim_size(d);
+ } else {
+ (*indices)[d] = starts_[d];
+ (*sizes)[d] = lengths_[d];
+ }
+ }
+ for (int d = dims(); d < NDIMS; ++d) {
+ (*indices)[d] = 0;
+ (*sizes)[d] = 1;
+ }
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_