// path: root/tensorflow/core/framework/tensor_util.h
// blob: 44513fe79f19515ed41f7daac83abf942e94a70e (plain)
#ifndef TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_
#define TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_

#include <vector>

#include "tensorflow/core/public/tensor.h"

namespace tensorflow {
namespace tensor {

// Returns a tensor whose contents are a deep copy of the contents of
// 'other'.
//
// This function is intended only for convenience, not speed.
//
// REQUIRES: 'other' must point to data stored in CPU memory.
// REQUIRES: 'other' must be a Tensor of a copy-able type if
//           'other' is not appropriately memory-aligned.
//
// NOTE(review): what happens when a REQUIRES clause is violated is not
// stated by this declaration -- confirm against the implementation.
Tensor DeepCopy(const Tensor& other);

// Concatenates 'tensors' into a single tensor, along their 0th dimension,
// and returns the result. Concat() and Split() are inverse operations.
//
// REQUIRES: All members of 'tensors' must have the same data type parameter.
// REQUIRES: Each member of 'tensors' must have at least one dimension.
// REQUIRES: Each member of 'tensors' must point to data stored in CPU memory.
// REQUIRES: Each member of 'tensors' must be a Tensor of a copy-able type if it
//           is not appropriately memory-aligned.
//
// NOTE(review): presumably every dimension other than the 0th must match
// across all members of 'tensors'; this declaration does not say so --
// confirm against the implementation.
Tensor Concat(const gtl::ArraySlice<Tensor>& tensors);

// Splits 'tensor' into 'sizes.size()' individual tensors, along the 0th
// dimension, and returns them in order. The ith output tensor has
// 0th-dimension size 'sizes[i]'.
//
// REQUIRES: 'tensor' must have at least one dimension.
// REQUIRES: 'tensor.dim_size(0)' must equal the sum of the elements of 'sizes'.
// REQUIRES: 'tensor' must point to data stored in CPU memory.
// REQUIRES: 'tensor' must be a Tensor of a copy-able type if it is not
//           appropriately memory-aligned.
//
// Split() and Concat() are inverse operations.
//
// NOTE(review): whether the returned tensors own copies of the data or
// alias the storage of 'tensor' is not stated here -- confirm against the
// implementation before relying on either.
std::vector<Tensor> Split(const Tensor& tensor,
                          const gtl::ArraySlice<int64>& sizes);

}  // namespace tensor
}  // namespace tensorflow

#endif  // TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_