aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/array3d.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-17 17:47:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-17 17:50:58 -0700
commit5f865f703621fed07925b3828f4a731066d98fd6 (patch)
treeb1a1921c764545cf715bae10a7fc1e17a8425b82 /tensorflow/compiler/xla/array3d.h
parent33ce1d06393f773c5317bb38ab996c2a7b8aa429 (diff)
The new class will be used as the base class for the existing 2-4
dimensional array classes to share code as well as for creating higher dimensional arrays. The API of the new class is kept compatible with the previous API to limit the scope of this change. PiperOrigin-RevId: 172543319
Diffstat (limited to 'tensorflow/compiler/xla/array3d.h')
-rw-r--r--tensorflow/compiler/xla/array3d.h94
1 files changed, 8 insertions, 86 deletions
diff --git a/tensorflow/compiler/xla/array3d.h b/tensorflow/compiler/xla/array3d.h
index 124ccd1975..e9449f01ad 100644
--- a/tensorflow/compiler/xla/array3d.h
+++ b/tensorflow/compiler/xla/array3d.h
@@ -24,6 +24,7 @@ limitations under the License.
#include <numeric>
#include <random>
+#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -32,22 +33,16 @@ limitations under the License.
namespace xla {
// Simple 3D array structure.
-//
-// The data layout in major-to-minor order is: n1, n2, n3.
template <typename T>
-class Array3D {
+class Array3D : public Array<T> {
public:
// Creates an array of dimensions n1 x n2 x n3, uninitialized values.
Array3D(const int64 n1, const int64 n2, const int64 n3)
- : n1_(n1), n2_(n2), n3_(n3), values_(new T[n1 * n2 * n3]) {
- Fill(T());
- }
+ : Array<T>(std::vector<int64>{n1, n2, n3}) {}
// Creates an array of dimensions n1 x n2 x n3, initialized to value.
Array3D(const int64 n1, const int64 n2, const int64 n3, const T value)
- : n1_(n1), n2_(n2), n3_(n3), values_(new T[n1 * n2 * n3]) {
- Fill(value);
- }
+ : Array<T>(std::vector<int64>{n1, n2, n3}, value) {}
// Creates an array from the given nested initializer list. The outer
// initializer list is the first dimension, and so on.
@@ -58,84 +53,11 @@ class Array3D {
// results in an array with n1=3, n2=4, n3=2.
Array3D(std::initializer_list<std::initializer_list<std::initializer_list<T>>>
values)
- : Array3D(values.size(), values.begin()->size(),
- values.begin()->begin()->size()) {
- int64 n1 = 0;
- for (auto n1_it = values.begin(); n1_it != values.end(); ++n1_it, ++n1) {
- int64 n2 = 0;
- for (auto n2_it = n1_it->begin(); n2_it != n1_it->end(); ++n2_it, ++n2) {
- int64 n3 = 0;
- for (auto n3_it = n2_it->begin(); n3_it != n2_it->end();
- ++n3_it, ++n3) {
- (*this)(n1, n2, n3) = *n3_it;
- }
- }
- }
- }
+ : Array<T>(values) {}
- Array3D(const Array3D<T>& other)
- : Array3D(other.n1(), other.n2(), other.n3()) {
- std::copy(&other.values_[0], &other.values_[0] + num_elements(),
- &values_[0]);
- }
-
- Array3D<T>& operator=(const Array3D<T>& other) {
- n1_ = other.n1();
- n2_ = other.n2();
- n3_ = other.n3();
- values_.reset(new T[num_elements()]);
- std::copy(&other.values_[0], &other.values_[0] + num_elements(),
- &values_[0]);
- return *this;
- }
-
- T& operator()(const int64 i1, const int64 i2, const int64 i3) {
- CHECK_LT(i1, n1_);
- CHECK_LT(i2, n2_);
- CHECK_LT(i3, n3_);
- return values_[i1 * n2_ * n3_ + i2 * n3_ + i3];
- }
-
- const T& operator()(const int64 i1, const int64 i2, const int64 i3) const {
- CHECK_LT(i1, n1_);
- CHECK_LT(i2, n2_);
- CHECK_LT(i3, n3_);
- return values_[i1 * n2_ * n3_ + i2 * n3_ + i3];
- }
-
- // Access to the array's dimensions.
- int64 n1() const { return n1_; }
- int64 n2() const { return n2_; }
- int64 n3() const { return n3_; }
- int64 num_elements() const { return n1_ * n2_ * n3_; }
-
- // Fills the array with the given value.
- void Fill(const T& value) {
- std::fill(&values_[0], &values_[0] + num_elements(), value);
- }
-
- // Fills the array with sequentially increasing values.
- void FillIota(const T& value) {
- std::iota(&values_[0], &values_[0] + num_elements(), value);
- }
-
- // Fills the array with random normal values with a mean of 0 and standard
- // deviation of value.
- void FillRandom(const T& value, const double mean = 0.0,
- const int seed = 12345) {
- std::mt19937 g(seed);
- std::normal_distribution<double> distribution(mean,
- static_cast<double>(value));
- for (int64 i = 0; i < num_elements(); ++i) {
- values_[i] = static_cast<T>(distribution(g));
- }
- }
-
- private:
- int64 n1_;
- int64 n2_;
- int64 n3_;
- std::unique_ptr<T[]> values_;
+ int64 n1() const { return this->dim(0); }
+ int64 n2() const { return this->dim(1); }
+ int64 n3() const { return this->dim(2); }
};
} // namespace xla