diff options
author | 2018-09-24 11:04:14 -0700 | |
---|---|---|
committer | 2018-09-24 11:11:38 -0700 | |
commit | 3bb3257a5f9675e6c094b9a6318d96d1bc27fc94 (patch) | |
tree | 730f2680227453abd863c1ed2661cab32635e167 /tensorflow/core/framework/tensor.cc | |
parent | f7017ef769bd603b61f25dfffc772e2153a9f076 (diff) |
Add functionality to SubSlice a tensor.
PiperOrigin-RevId: 214295534
Diffstat (limited to 'tensorflow/core/framework/tensor.cc')
-rw-r--r-- | tensorflow/core/framework/tensor.cc | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index eb9c79ff2d..3df677675e 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -812,6 +812,28 @@ Tensor Tensor::Slice(int64 start, int64 limit) const { return ret; } +Tensor Tensor::SubSlice(int64 index) const { + CHECK_GE(dims(), 2); // Crash ok. + CHECK_LE(0, index); // Crash ok. + int64 dim0_size = shape_.dim_size(0); + CHECK_LE(index, dim0_size); // Crash ok. + Tensor ret; + ret.shape_ = shape_; + ret.shape_.RemoveDim(0); + ret.set_dtype(dtype()); + ret.buf_ = nullptr; + if (dim0_size > 0) { + const int64 elems_per_dim0 = NumElements() / dim0_size; + const int64 delta = index * elems_per_dim0; + const int64 num_elems = elems_per_dim0; + if (buf_) { + DataType dt = dtype(); + CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems)); + } + } + return ret; +} + bool Tensor::FromProto(const TensorProto& proto) { return FromProto(cpu_allocator(), proto); } |