diff options
author | 2017-05-17 15:38:59 -0700 | |
---|---|---|
committer | 2017-05-17 15:45:11 -0700 | |
commit | 9d9aa78c105a2414deed6ced31fb7699f08e55bb (patch) | |
tree | 794cea0ef4ad24a2a2c3a370a6d376c30fdb94b9 /tensorflow/compiler/xla/array4d.h | |
parent | ebcc089d41598d59fd93ae2a2aac1203571545c4 (diff) |
[XLA] Fix bool support for Array2D/Array3D/Array4D.
The Array?D<T> classes were previously backed by a std::vector<T>, which has an inconsistent bitset implementation for T == bool. This meant that the bool instantiations of Array?D caused compile-time errors. Change the Array classes to be backed by a std::unique_ptr<T[]> instead, since the dimensions of an Array are constant.
PiperOrigin-RevId: 156364397
Diffstat (limited to 'tensorflow/compiler/xla/array4d.h')
-rw-r--r-- | tensorflow/compiler/xla/array4d.h | 44 |
1 files changed, 31 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index 199ad2baae..1c6ba1f519 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -20,6 +20,7 @@ limitations under the License. #include <functional> #include <initializer_list> #include <iterator> +#include <memory> #include <numeric> #include <random> #include <string> @@ -60,15 +61,15 @@ class Array4D { depth_(depth), height_(height), width_(width), - values_(planes * depth * height * width) {} + values_(new T[planes * depth * height * width]) { + Fill(T()); + } // Creates a 4D array, initalized to value. Array4D(int64 planes, int64 depth, int64 height, int64 width, T value) - : planes_(planes), - depth_(depth), - height_(height), - width_(width), - values_(planes * depth * height * width, value) {} + : Array4D(planes, depth, height, width) { + Fill(value); + } // Creates a 4D array, filled with values. // @@ -111,6 +112,23 @@ class Array4D { } } + Array4D(const Array4D<T>& other) + : Array4D(other.planes(), other.depth(), other.height(), other.width()) { + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); + } + + Array4D<T>& operator=(const Array4D<T>& other) { + planes_ = other.planes(); + depth_ = other.depth(); + height_ = other.height(); + width_ = other.width(); + values_.reset(new T[num_elements()]); + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); + return *this; + } + T& operator()(int64 plane, int64 depth, int64 height, int64 width) { CHECK_LT(plane, planes_); CHECK_LT(depth, depth_); @@ -135,24 +153,24 @@ class Array4D { int64 n3() const { return height_; } int64 n2() const { return depth_; } int64 n1() const { return planes_; } - int64 num_elements() const { return values_.size(); } + int64 num_elements() const { return width_ * height_ * depth_ * planes_; } // Sets all the values in the array to values. template <typename Container = std::initializer_list<T>> void SetValues(const Container& container) { CHECK_EQ(std::distance(std::begin(container), std::end(container)), num_elements()); - values_.assign(std::begin(container), std::end(container)); + std::copy(std::begin(container), std::end(container), &values_[0]); } // Fills the array with the given value. void Fill(const T& value) { - std::fill(values_.begin(), values_.end(), value); + std::fill(&values_[0], &values_[0] + num_elements(), value); } // Fills the array with iota. void FillIota(const T& value) { - std::iota(values_.begin(), values_.end(), value); + std::iota(&values_[0], &values_[0] + num_elements(), value); } // Fills the array with random variable with a deviation of value and a mean @@ -162,8 +180,8 @@ class Array4D { std::mt19937 g(seed); std::normal_distribution<double> distribution(mean, static_cast<double>(value)); - for (auto& v : values_) { - v = static_cast<T>(distribution(g)); + for (int64 i = 0; i < num_elements(); ++i) { + values_[i] = static_cast<T>(distribution(g)); } } @@ -268,7 +286,7 @@ class Array4D { int64 depth_; int64 height_; int64 width_; - std::vector<T> values_; + std::unique_ptr<T[]> values_; }; } // namespace xla |