aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/dnn.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-04-27 10:36:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-04-27 11:42:39 -0700
commit57b4b9b401f5600d58e040664471a2cdcbcc818f (patch)
treed9de8c5b86abb41e57a9d62126e7805271162b2e /tensorflow/stream_executor/dnn.cc
parentda0433e0c1b488a58ff2c993e9b1a048c09abd5c (diff)
Generalize the stream executor interface to support N-d operations.
Change: 120936645
Diffstat (limited to 'tensorflow/stream_executor/dnn.cc')
-rw-r--r--tensorflow/stream_executor/dnn.cc251
1 files changed, 186 insertions, 65 deletions
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index 4dc60aea53..41f3bf83d3 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -103,34 +103,115 @@ string ShortPoolingModeString(PoolingMode mode) {
}
}
+std::tuple<int, int, int> GetDimIndices(const DataLayout& layout,
+ const int data_dims) {
+ int depth_idx, batch_idx, spatial_idx;
+ switch (layout) {
+ case DataLayout::kYXBatchDepth:
+ depth_idx = data_dims - 1;
+ batch_idx = data_dims - 2;
+ spatial_idx = 0;
+ break;
+
+ case DataLayout::kYXDepthBatch:
+ depth_idx = data_dims - 2;
+ batch_idx = data_dims - 1;
+ spatial_idx = 0;
+ break;
+
+ case DataLayout::kBatchYXDepth:
+ depth_idx = data_dims - 1;
+ batch_idx = 0;
+ spatial_idx = 1;
+ break;
+
+ case DataLayout::kBatchDepthYX:
+ depth_idx = 1;
+ batch_idx = 0;
+ spatial_idx = 2;
+ break;
+ }
+
+ return std::make_tuple(depth_idx, batch_idx, spatial_idx);
+}
+
+std::vector<int64> ReorderDims(const std::vector<int64>& input,
+ const DataLayout& from, const DataLayout& to) {
+ if (from == to) return input;
+
+ int d_idx_from, b_idx_from, spatial_idx_from;
+ int d_idx_to, b_idx_to, spatial_idx_to;
+
+ std::tie(d_idx_from, b_idx_from, spatial_idx_from) =
+ GetDimIndices(from, input.size());
+ std::tie(d_idx_to, b_idx_to, spatial_idx_to) =
+ GetDimIndices(to, input.size());
+
+ std::vector<int64> reordered(input.size());
+ reordered[b_idx_to] = input[b_idx_from];
+ reordered[d_idx_to] = input[d_idx_from];
+
+ for (int i = 0; i < input.size() - 2;
+ i++, spatial_idx_from++, spatial_idx_to++) {
+ reordered[spatial_idx_to] = input[spatial_idx_from];
+ }
+
+ return reordered;
+}
+
// -- BatchDescriptor
-BatchDescriptor::BatchDescriptor()
+BatchDescriptor::BatchDescriptor(int ndims)
: count_(0),
feature_map_count_(0),
- height_(0),
- width_(0),
+ spatial_size_(ndims, 0),
value_max_(0.0),
value_min_(0.0),
layout_(DataLayout::kYXDepthBatch),
+ ndims_(ndims),
quantized_activation_mode_(QuantizedActivationMode::k8Bit) {}
+BatchDescriptor::BatchDescriptor() : BatchDescriptor(/*ndims=*/2) {}
+
+std::vector<int64> BatchDescriptor::full_dims(const DataLayout& layout) const {
+ std::vector<int64> bdyx_dims(ndims_ + 2);
+ bdyx_dims[0] = count();
+ bdyx_dims[1] = feature_map_count();
+ std::copy(spatial_size_.begin(), spatial_size_.end(), bdyx_dims.begin() + 2);
+ return ReorderDims(bdyx_dims, DataLayout::kBatchDepthYX, layout);
+}
+
+std::vector<int64> BatchDescriptor::full_strides(
+ const DataLayout& layout) const {
+ std::vector<int64> phys_dims = full_dims(layout_);
+ std::vector<int64> phys_strides(phys_dims.size());
+ phys_strides[ndims_ + 1] = 1;
+ for (int i = ndims_; i >= 0; i--) {
+ phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1];
+ }
+ return ReorderDims(phys_strides, layout_, layout);
+}
+
void BatchDescriptor::CloneFrom(const BatchDescriptor& other) {
count_ = other.count_;
feature_map_count_ = other.feature_map_count_;
- height_ = other.height_;
- width_ = other.width_;
+ spatial_size_ = other.spatial_size_;
value_max_ = other.value_max_;
value_min_ = other.value_min_;
layout_ = other.layout_;
+ ndims_ = other.ndims_;
quantized_activation_mode_ = other.quantized_activation_mode_;
}
string BatchDescriptor::ToString() const {
+ string spatial;
+ for (int i = 0; i < ndims_; i++) {
+ port::Appendf(&spatial, "%lld ", spatial_size_[i]);
+ }
return port::Printf(
- "{count: %lld feature_map_count: %lld height: %lld width: %lld "
+ "{count: %lld feature_map_count: %lld spatial: %s "
"value_min: %f value_max: %f layout: %s}",
- count_, feature_map_count_, height_, width_, value_min_, value_max_,
+ count_, feature_map_count_, spatial.c_str(), value_min_, value_max_,
DataLayoutString(layout_).c_str());
}
@@ -138,11 +219,14 @@ string BatchDescriptor::ToShortString() const {
// All the constituent strings are less than 15 characters, so the
// small string optimization ensures that there will be at most one
// heap memory allocation.
- string x = port::StrCat("x", width());
- string y = port::StrCat("y", height());
string depth = port::StrCat("d", feature_map_count());
string batch = port::StrCat("b", count());
+ string spatial = "s";
+ for (int i = 0; i < ndims_; i++) {
+ port::Appendf(&spatial, "%lld ", spatial_size_[i]);
+ }
+
string suffix;
if (value_min() != value_max()) {
port::StrAppend(&suffix, "[", value_min(), ";", value_max(), "]");
@@ -153,27 +237,33 @@ string BatchDescriptor::ToShortString() const {
switch (layout()) {
case DataLayout::kYXDepthBatch:
- return port::StrCat(y, x, depth, batch, suffix);
+ return port::StrCat(spatial, depth, batch, suffix);
case DataLayout::kYXBatchDepth:
- return port::StrCat(y, x, batch, depth, suffix);
+ return port::StrCat(spatial, batch, depth, suffix);
case DataLayout::kBatchYXDepth:
- return port::StrCat(batch, y, x, depth, suffix);
+ return port::StrCat(batch, spatial, depth, suffix);
case DataLayout::kBatchDepthYX:
- return port::StrCat(batch, depth, y, x, suffix);
+ return port::StrCat(batch, depth, spatial, suffix);
default:
LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
return ""; // Avoid return warning (unreachable)
}
}
-int64 BatchDescriptor::NodesPerFeatureMap() const { return width_ * height_; }
+int64 BatchDescriptor::NodesPerFeatureMap() const {
+ int64 ret = 1;
+ for (int i = 0; i < ndims_; i++) {
+ ret *= spatial_size_[i];
+ }
+ return ret;
+}
int64 BatchDescriptor::NodesAcrossFeatureMaps() const {
return NodesPerFeatureMap() * feature_map_count_;
}
int64 BatchDescriptor::ElementCount() const {
- return count_ * feature_map_count_ * height_ * width_;
+ return count_ * feature_map_count_ * NodesPerFeatureMap();
}
int64 BatchDescriptor::FullyConnectedWeightCount(
@@ -201,29 +291,37 @@ BatchDescriptor BatchDescriptor::DepthConcatenateOutputDescriptor(
// -- FilterDescriptor
-FilterDescriptor::FilterDescriptor()
+FilterDescriptor::FilterDescriptor(int ndims)
: output_feature_map_count_(0),
input_feature_map_count_(0),
- input_filter_height_(0),
- input_filter_width_(0),
+ input_filter_dims_(ndims, 0),
+ ndims_(ndims),
layout_(FilterLayout::kOutputInputYX) {}
+FilterDescriptor::FilterDescriptor() : FilterDescriptor(/*ndims=*/2) {}
+
FilterDescriptor::~FilterDescriptor() {}
void FilterDescriptor::CloneFrom(const FilterDescriptor& other) {
set_output_feature_map_count(other.output_feature_map_count())
.set_input_feature_map_count(other.input_feature_map_count())
- .set_input_filter_height(other.input_filter_height())
- .set_input_filter_width(other.input_filter_width())
.set_layout(other.layout());
+ input_filter_dims_ = other.input_filter_dims_;
+ ndims_ = other.ndims_;
}
string FilterDescriptor::ToString() const {
- return port::Printf(
+ string desc = port::Printf(
"{output_feature_map_count: %lld input_feature_map_count: %lld "
- "input_filter_height: %lld input_filter_width: %lld layout: %s}",
- output_feature_map_count_, input_feature_map_count_, input_filter_height_,
- input_filter_width_, FilterLayoutString(layout_).c_str());
+ "layout: %s shape: ",
+ output_feature_map_count_, input_feature_map_count_,
+ FilterLayoutString(layout_).c_str());
+ for (int i = 0; i < ndims_; i++) {
+ port::Appendf(&desc, "%lld ", input_filter_dims_[i]);
+ }
+ port::StrAppend(&desc, "}");
+
+ return desc;
}
string FilterDescriptor::ToShortString() const {
@@ -232,16 +330,19 @@ string FilterDescriptor::ToShortString() const {
// heap memory allocation.
string od = port::StrCat("od", output_feature_map_count_);
string id = port::StrCat("id", input_feature_map_count_);
- string y = port::StrCat("y", input_filter_height_);
- string x = port::StrCat("x", input_filter_width_);
+
+ string spatial = "s";
+ for (int i = 0; i < ndims_; i++) {
+ port::Appendf(&spatial, "%lld ", input_filter_dims_[i]);
+ }
switch (layout_) {
case FilterLayout::kOutputInputYX:
- return port::StrCat(od, id, y, x);
+ return port::StrCat(od, id, spatial);
case FilterLayout::kInputYXOutput:
- return port::StrCat(id, y, x, od);
+ return port::StrCat(id, spatial, od);
case FilterLayout::kYXInputOutput:
- return port::StrCat(y, x, id, od);
+ return port::StrCat(spatial, id, od);
default:
LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout_);
return ""; // Avoid return warning (unreachable)
@@ -249,71 +350,91 @@ string FilterDescriptor::ToShortString() const {
}
int64 FilterDescriptor::ComputeWeightCount() const {
- return output_feature_map_count_ * input_feature_map_count_ *
- input_filter_height_ * input_filter_width_;
+ int64 ret = output_feature_map_count_ * input_feature_map_count_;
+ for (int i = 0; i < ndims_; i++) {
+ ret *= input_filter_dims_[i];
+ }
+ return ret;
}
// -- ConvolutionDescriptor
+ConvolutionDescriptor::ConvolutionDescriptor(int ndims)
+ : zero_padding_(ndims, 0), filter_strides_(ndims, 1), ndims_(ndims) {}
+
ConvolutionDescriptor::ConvolutionDescriptor()
- : zero_padding_height_(0),
- zero_padding_width_(0),
- vertical_filter_stride_(1),
- horizontal_filter_stride_(1) {}
+ : ConvolutionDescriptor(/*ndims=*/2) {}
ConvolutionDescriptor::~ConvolutionDescriptor() {}
string ConvolutionDescriptor::ToString() const {
- return port::Printf(
- "{zero_padding_height: %lld zero_padding_width: %lld "
- "vertical_filter_stride: %lld horizontal_filter_stride: %lld}",
- zero_padding_height_, zero_padding_width_, vertical_filter_stride_,
- horizontal_filter_stride_);
+ string padding;
+ string strides;
+ for (int i = 0; i < ndims_; i++) {
+ port::Appendf(&padding, "%lld ", zero_padding_[i]);
+ port::Appendf(&strides, "%lld ", filter_strides_[i]);
+ }
+
+ return port::Printf("{zero_padding: %s filter_strides: %s}", padding.c_str(),
+ strides.c_str());
}
string ConvolutionDescriptor::ToShortString() const {
- return port::StrCat("py:", zero_padding_height_, "_px:", zero_padding_width_,
- "_sy:", vertical_filter_stride_, "_sx:",
- horizontal_filter_stride_);
+ string desc;
+ for (int i = 0; i < ndims_; i++) {
+ if (i > 0) port::Appendf(&desc, "_");
+ port::Appendf(&desc, "p%d:%lld", i, zero_padding_[i]);
+ }
+ for (int i = 0; i < ndims_; i++) {
+ port::Appendf(&desc, "_s%d:%lld", i, filter_strides_[i]);
+ }
+ return desc;
}
// -- PoolingDescriptor
-PoolingDescriptor::PoolingDescriptor()
+PoolingDescriptor::PoolingDescriptor(int ndims)
: mode_(dnn::PoolingMode::kMaximum),
- window_height_(0),
- window_width_(0),
- vertical_padding_(0),
- horizontal_padding_(0),
- vertical_stride_(0),
- horizontal_stride_(0) {}
+ ndims_(ndims),
+ window_(ndims, 0),
+ padding_(ndims, 0),
+ strides_(ndims, 1) {}
+
+PoolingDescriptor::PoolingDescriptor() : PoolingDescriptor(/*ndims=*/2) {}
void PoolingDescriptor::CloneFrom(const PoolingDescriptor& other) {
mode_ = other.mode_;
- window_height_ = other.window_height_;
- window_width_ = other.window_width_;
- vertical_padding_ = other.vertical_padding_;
- horizontal_padding_ = other.horizontal_padding_;
- vertical_stride_ = other.vertical_stride_;
- horizontal_stride_ = other.horizontal_stride_;
+ ndims_ = other.ndims_;
+ window_ = other.window_;
+ padding_ = other.padding_;
+ strides_ = other.strides_;
}
string PoolingDescriptor::ToString() const {
const char* mode_string =
mode_ == dnn::PoolingMode::kMaximum ? "kMaximum" : "kAverage";
- return port::Printf(
- "{mode: %s window_height: %lld window_width: %lld vertical_stride: %lld "
- "horizontal_stride: %lld vertical padding: %lld horizontal padding: "
- "%lld}",
- mode_string, window_height_, window_width_, vertical_stride_,
- horizontal_stride_, vertical_padding_, horizontal_padding_);
+
+ string window, strides, padding;
+ for (int i = 0; i < ndims_; i++) {
+ port::Appendf(&window, "%lld ", window_[i]);
+ port::Appendf(&strides, "%lld ", strides_[i]);
+ port::Appendf(&padding, "%lld", padding_[i]);
+ }
+
+ return port::Printf("{mode: %s window: %s strides: %s padding: %s}",
+ mode_string, window.c_str(), strides.c_str(),
+ padding.c_str());
}
string PoolingDescriptor::ToShortString() const {
+ string window, strides, padding;
+ for (int i = 0; i < ndims_; i++) {
+ port::Appendf(&window, "_w%d:%lld", i, window_[i]);
+ port::Appendf(&strides, "_s%d:%lld", i, strides_[i]);
+ port::Appendf(&padding, "_p%d:%lld", i, padding_[i]);
+ }
return port::StrCat(mode_ == dnn::PoolingMode::kMaximum ? "max" : "avg",
- "_y:", window_height_, "_x:", window_width_, "_py:",
- vertical_padding_, "_px:", horizontal_padding_, "_sy:",
- vertical_stride_, "_sx:", horizontal_stride_);
+ window, strides, padding);
}
// -- NormalizeDescriptor