aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/array4d.h
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-05-17 15:38:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-17 15:45:11 -0700
commit9d9aa78c105a2414deed6ced31fb7699f08e55bb (patch)
tree794cea0ef4ad24a2a2c3a370a6d376c30fdb94b9 /tensorflow/compiler/xla/array4d.h
parentebcc089d41598d59fd93ae2a2aac1203571545c4 (diff)
[XLA] Fix bool support for Array2D/Array3D/Array4D.
The Array?D<T> classes were previously backed by a std::vector<T>, which has an inconsistent bitset implementation for T == bool. This meant that the bool instantiations of Array?D caused compile-time errors. Change the Array classes to be backed by a std::unique_ptr<T[]> instead, since the dimensions of an Array are constant. PiperOrigin-RevId: 156364397
Diffstat (limited to 'tensorflow/compiler/xla/array4d.h')
-rw-r--r--tensorflow/compiler/xla/array4d.h44
1 files changed, 31 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h
index 199ad2baae..1c6ba1f519 100644
--- a/tensorflow/compiler/xla/array4d.h
+++ b/tensorflow/compiler/xla/array4d.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <functional>
#include <initializer_list>
#include <iterator>
+#include <memory>
#include <numeric>
#include <random>
#include <string>
@@ -60,15 +61,15 @@ class Array4D {
depth_(depth),
height_(height),
width_(width),
- values_(planes * depth * height * width) {}
+ values_(new T[planes * depth * height * width]) {
+ Fill(T());
+ }
// Creates a 4D array, initalized to value.
Array4D(int64 planes, int64 depth, int64 height, int64 width, T value)
- : planes_(planes),
- depth_(depth),
- height_(height),
- width_(width),
- values_(planes * depth * height * width, value) {}
+ : Array4D(planes, depth, height, width) {
+ Fill(value);
+ }
// Creates a 4D array, filled with values.
//
@@ -111,6 +112,23 @@ class Array4D {
}
}
+ Array4D(const Array4D<T>& other)
+ : Array4D(other.planes(), other.depth(), other.height(), other.width()) {
+ std::copy(&other.values_[0], &other.values_[0] + num_elements(),
+ &values_[0]);
+ }
+
+ Array4D<T>& operator=(const Array4D<T>& other) {
+ planes_ = other.planes();
+ depth_ = other.depth();
+ height_ = other.height();
+ width_ = other.width();
+ values_.reset(new T[num_elements()]);
+ std::copy(&other.values_[0], &other.values_[0] + num_elements(),
+ &values_[0]);
+ return *this;
+ }
+
T& operator()(int64 plane, int64 depth, int64 height, int64 width) {
CHECK_LT(plane, planes_);
CHECK_LT(depth, depth_);
@@ -135,24 +153,24 @@ class Array4D {
int64 n3() const { return height_; }
int64 n2() const { return depth_; }
int64 n1() const { return planes_; }
- int64 num_elements() const { return values_.size(); }
+ int64 num_elements() const { return width_ * height_ * depth_ * planes_; }
// Sets all the values in the array to values.
template <typename Container = std::initializer_list<T>>
void SetValues(const Container& container) {
CHECK_EQ(std::distance(std::begin(container), std::end(container)),
num_elements());
- values_.assign(std::begin(container), std::end(container));
+ std::copy(std::begin(container), std::end(container), &values_[0]);
}
// Fills the array with the given value.
void Fill(const T& value) {
- std::fill(values_.begin(), values_.end(), value);
+ std::fill(&values_[0], &values_[0] + num_elements(), value);
}
// Fills the array with iota.
void FillIota(const T& value) {
- std::iota(values_.begin(), values_.end(), value);
+ std::iota(&values_[0], &values_[0] + num_elements(), value);
}
// Fills the array with random variable with a deviation of value and a mean
@@ -162,8 +180,8 @@ class Array4D {
std::mt19937 g(seed);
std::normal_distribution<double> distribution(mean,
static_cast<double>(value));
- for (auto& v : values_) {
- v = static_cast<T>(distribution(g));
+ for (int64 i = 0; i < num_elements(); ++i) {
+ values_[i] = static_cast<T>(distribution(g));
}
}
@@ -268,7 +286,7 @@ class Array4D {
int64 depth_;
int64 height_;
int64 width_;
- std::vector<T> values_;
+ std::unique_ptr<T[]> values_;
};
} // namespace xla