diff options
author | A. Unique TensorFlower <nobody@tensorflow.org> | 2016-04-27 10:36:19 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-04-27 11:42:39 -0700 |
commit | 57b4b9b401f5600d58e040664471a2cdcbcc818f (patch) | |
tree | d9de8c5b86abb41e57a9d62126e7805271162b2e /tensorflow/stream_executor/dnn.cc | |
parent | da0433e0c1b488a58ff2c993e9b1a048c09abd5c (diff) |
Generalize the stream executor interface to support N-d operations.
Change: 120936645
Diffstat (limited to 'tensorflow/stream_executor/dnn.cc')
-rw-r--r-- | tensorflow/stream_executor/dnn.cc | 251 |
1 files changed, 186 insertions, 65 deletions
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc index 4dc60aea53..41f3bf83d3 100644 --- a/tensorflow/stream_executor/dnn.cc +++ b/tensorflow/stream_executor/dnn.cc @@ -103,34 +103,115 @@ string ShortPoolingModeString(PoolingMode mode) { } } +std::tuple<int, int, int> GetDimIndices(const DataLayout& layout, + const int data_dims) { + int depth_idx, batch_idx, spatial_idx; + switch (layout) { + case DataLayout::kYXBatchDepth: + depth_idx = data_dims - 1; + batch_idx = data_dims - 2; + spatial_idx = 0; + break; + + case DataLayout::kYXDepthBatch: + depth_idx = data_dims - 2; + batch_idx = data_dims - 1; + spatial_idx = 0; + break; + + case DataLayout::kBatchYXDepth: + depth_idx = data_dims - 1; + batch_idx = 0; + spatial_idx = 1; + break; + + case DataLayout::kBatchDepthYX: + depth_idx = 1; + batch_idx = 0; + spatial_idx = 2; + break; + } + + return std::make_tuple(depth_idx, batch_idx, spatial_idx); +} + +std::vector<int64> ReorderDims(const std::vector<int64>& input, + const DataLayout& from, const DataLayout& to) { + if (from == to) return input; + + int d_idx_from, b_idx_from, spatial_idx_from; + int d_idx_to, b_idx_to, spatial_idx_to; + + std::tie(d_idx_from, b_idx_from, spatial_idx_from) = + GetDimIndices(from, input.size()); + std::tie(d_idx_to, b_idx_to, spatial_idx_to) = + GetDimIndices(to, input.size()); + + std::vector<int64> reordered(input.size()); + reordered[b_idx_to] = input[b_idx_from]; + reordered[d_idx_to] = input[d_idx_from]; + + for (int i = 0; i < input.size() - 2; + i++, spatial_idx_from++, spatial_idx_to++) { + reordered[spatial_idx_to] = input[spatial_idx_from]; + } + + return reordered; +} + // -- BatchDescriptor -BatchDescriptor::BatchDescriptor() +BatchDescriptor::BatchDescriptor(int ndims) : count_(0), feature_map_count_(0), - height_(0), - width_(0), + spatial_size_(ndims, 0), value_max_(0.0), value_min_(0.0), layout_(DataLayout::kYXDepthBatch), + ndims_(ndims), quantized_activation_mode_(QuantizedActivationMode::k8Bit) {} +BatchDescriptor::BatchDescriptor() : BatchDescriptor(/*ndims=*/2) {} + +std::vector<int64> BatchDescriptor::full_dims(const DataLayout& layout) const { + std::vector<int64> bdyx_dims(ndims_ + 2); + bdyx_dims[0] = count(); + bdyx_dims[1] = feature_map_count(); + std::copy(spatial_size_.begin(), spatial_size_.end(), bdyx_dims.begin() + 2); + return ReorderDims(bdyx_dims, DataLayout::kBatchDepthYX, layout); +} + +std::vector<int64> BatchDescriptor::full_strides( + const DataLayout& layout) const { + std::vector<int64> phys_dims = full_dims(layout_); + std::vector<int64> phys_strides(phys_dims.size()); + phys_strides[ndims_ + 1] = 1; + for (int i = ndims_; i >= 0; i--) { + phys_strides[i] = phys_strides[i + 1] * phys_dims[i + 1]; + } + return ReorderDims(phys_strides, layout_, layout); +} + void BatchDescriptor::CloneFrom(const BatchDescriptor& other) { count_ = other.count_; feature_map_count_ = other.feature_map_count_; - height_ = other.height_; - width_ = other.width_; + spatial_size_ = other.spatial_size_; value_max_ = other.value_max_; value_min_ = other.value_min_; layout_ = other.layout_; + ndims_ = other.ndims_; quantized_activation_mode_ = other.quantized_activation_mode_; } string BatchDescriptor::ToString() const { + string spatial; + for (int i = 0; i < ndims_; i++) { + port::Appendf(&spatial, "%lld ", spatial_size_[i]); + } return port::Printf( - "{count: %lld feature_map_count: %lld height: %lld width: %lld " + "{count: %lld feature_map_count: %lld spatial: %s " "value_min: %f value_max: %f layout: %s}", - count_, feature_map_count_, height_, width_, value_min_, value_max_, + count_, feature_map_count_, spatial.c_str(), value_min_, value_max_, DataLayoutString(layout_).c_str()); } @@ -138,11 +219,14 @@ string BatchDescriptor::ToShortString() const { // All the constituent strings are less than 15 characters, so the // small string optimization ensures that there will be at most one // heap memory allocation. - string x = port::StrCat("x", width()); - string y = port::StrCat("y", height()); string depth = port::StrCat("d", feature_map_count()); string batch = port::StrCat("b", count()); + string spatial = "s"; + for (int i = 0; i < ndims_; i++) { + port::Appendf(&spatial, "%lld ", spatial_size_[i]); + } + string suffix; if (value_min() != value_max()) { port::StrAppend(&suffix, "[", value_min(), ";", value_max(), "]"); @@ -153,27 +237,33 @@ string BatchDescriptor::ToShortString() const { switch (layout()) { case DataLayout::kYXDepthBatch: - return port::StrCat(y, x, depth, batch, suffix); + return port::StrCat(spatial, depth, batch, suffix); case DataLayout::kYXBatchDepth: - return port::StrCat(y, x, batch, depth, suffix); + return port::StrCat(spatial, batch, depth, suffix); case DataLayout::kBatchYXDepth: - return port::StrCat(batch, y, x, depth, suffix); + return port::StrCat(batch, spatial, depth, suffix); case DataLayout::kBatchDepthYX: - return port::StrCat(batch, depth, y, x, suffix); + return port::StrCat(batch, depth, spatial, suffix); default: LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout()); return ""; // Avoid return warning (unreachable) } } -int64 BatchDescriptor::NodesPerFeatureMap() const { return width_ * height_; } +int64 BatchDescriptor::NodesPerFeatureMap() const { + int64 ret = 1; + for (int i = 0; i < ndims_; i++) { + ret *= spatial_size_[i]; + } + return ret; +} int64 BatchDescriptor::NodesAcrossFeatureMaps() const { return NodesPerFeatureMap() * feature_map_count_; } int64 BatchDescriptor::ElementCount() const { - return count_ * feature_map_count_ * height_ * width_; + return count_ * feature_map_count_ * NodesPerFeatureMap(); } int64 BatchDescriptor::FullyConnectedWeightCount( @@ -201,29 +291,37 @@ BatchDescriptor BatchDescriptor::DepthConcatenateOutputDescriptor( // -- FilterDescriptor -FilterDescriptor::FilterDescriptor() +FilterDescriptor::FilterDescriptor(int ndims) : output_feature_map_count_(0), input_feature_map_count_(0), - input_filter_height_(0), - input_filter_width_(0), + input_filter_dims_(ndims, 0), + ndims_(ndims), layout_(FilterLayout::kOutputInputYX) {} +FilterDescriptor::FilterDescriptor() : FilterDescriptor(/*ndims=*/2) {} + FilterDescriptor::~FilterDescriptor() {} void FilterDescriptor::CloneFrom(const FilterDescriptor& other) { set_output_feature_map_count(other.output_feature_map_count()) .set_input_feature_map_count(other.input_feature_map_count()) - .set_input_filter_height(other.input_filter_height()) - .set_input_filter_width(other.input_filter_width()) .set_layout(other.layout()); + input_filter_dims_ = other.input_filter_dims_; + ndims_ = other.ndims_; } string FilterDescriptor::ToString() const { - return port::Printf( + string desc = port::Printf( "{output_feature_map_count: %lld input_feature_map_count: %lld " - "input_filter_height: %lld input_filter_width: %lld layout: %s}", - output_feature_map_count_, input_feature_map_count_, input_filter_height_, - input_filter_width_, FilterLayoutString(layout_).c_str()); + "layout: %s shape: ", + output_feature_map_count_, input_feature_map_count_, + FilterLayoutString(layout_).c_str()); + for (int i = 0; i < ndims_; i++) { + port::Appendf(&desc, "%lld ", input_filter_dims_[i]); + } + port::StrAppend(&desc, "}"); + + return desc; } string FilterDescriptor::ToShortString() const { @@ -232,16 +330,19 @@ string FilterDescriptor::ToShortString() const { // heap memory allocation. string od = port::StrCat("od", output_feature_map_count_); string id = port::StrCat("id", input_feature_map_count_); - string y = port::StrCat("y", input_filter_height_); - string x = port::StrCat("x", input_filter_width_); + + string spatial = "s"; + for (int i = 0; i < ndims_; i++) { + port::Appendf(&spatial, "%lld ", input_filter_dims_[i]); + } switch (layout_) { case FilterLayout::kOutputInputYX: - return port::StrCat(od, id, y, x); + return port::StrCat(od, id, spatial); case FilterLayout::kInputYXOutput: - return port::StrCat(id, y, x, od); + return port::StrCat(id, spatial, od); case FilterLayout::kYXInputOutput: - return port::StrCat(y, x, id, od); + return port::StrCat(spatial, id, od); default: LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout_); return ""; // Avoid return warning (unreachable) @@ -249,71 +350,91 @@ string FilterDescriptor::ToShortString() const { } int64 FilterDescriptor::ComputeWeightCount() const { - return output_feature_map_count_ * input_feature_map_count_ * - input_filter_height_ * input_filter_width_; + int64 ret = output_feature_map_count_ * input_feature_map_count_; + for (int i = 0; i < ndims_; i++) { + ret *= input_filter_dims_[i]; + } + return ret; } // -- ConvolutionDescriptor +ConvolutionDescriptor::ConvolutionDescriptor(int ndims) + : zero_padding_(ndims, 0), filter_strides_(ndims, 1), ndims_(ndims) {} + ConvolutionDescriptor::ConvolutionDescriptor() - : zero_padding_height_(0), - zero_padding_width_(0), - vertical_filter_stride_(1), - horizontal_filter_stride_(1) {} + : ConvolutionDescriptor(/*ndims=*/2) {} ConvolutionDescriptor::~ConvolutionDescriptor() {} string ConvolutionDescriptor::ToString() const { - return port::Printf( - "{zero_padding_height: %lld zero_padding_width: %lld " - "vertical_filter_stride: %lld horizontal_filter_stride: %lld}", - zero_padding_height_, zero_padding_width_, vertical_filter_stride_, - horizontal_filter_stride_); + string padding; + string strides; + for (int i = 0; i < ndims_; i++) { + port::Appendf(&padding, "%lld ", zero_padding_[i]); + port::Appendf(&strides, "%lld ", filter_strides_[i]); + } + + return port::Printf("{zero_padding: %s filter_strides: %s}", padding.c_str(), + strides.c_str()); } string ConvolutionDescriptor::ToShortString() const { - return port::StrCat("py:", zero_padding_height_, "_px:", zero_padding_width_, - "_sy:", vertical_filter_stride_, "_sx:", - horizontal_filter_stride_); + string desc; + for (int i = 0; i < ndims_; i++) { + if (i > 0) port::Appendf(&desc, "_"); + port::Appendf(&desc, "p%d:%lld", i, zero_padding_[i]); + } + for (int i = 0; i < ndims_; i++) { + port::Appendf(&desc, "_s%d:%lld", i, filter_strides_[i]); + } + return desc; } // -- PoolingDescriptor -PoolingDescriptor::PoolingDescriptor() +PoolingDescriptor::PoolingDescriptor(int ndims) : mode_(dnn::PoolingMode::kMaximum), - window_height_(0), - window_width_(0), - vertical_padding_(0), - horizontal_padding_(0), - vertical_stride_(0), - horizontal_stride_(0) {} + ndims_(ndims), + window_(ndims, 0), + padding_(ndims, 0), + strides_(ndims, 1) {} + +PoolingDescriptor::PoolingDescriptor() : PoolingDescriptor(/*ndims=*/2) {} void PoolingDescriptor::CloneFrom(const PoolingDescriptor& other) { mode_ = other.mode_; - window_height_ = other.window_height_; - window_width_ = other.window_width_; - vertical_padding_ = other.vertical_padding_; - horizontal_padding_ = other.horizontal_padding_; - vertical_stride_ = other.vertical_stride_; - horizontal_stride_ = other.horizontal_stride_; + ndims_ = other.ndims_; + window_ = other.window_; + padding_ = other.padding_; + strides_ = other.strides_; } string PoolingDescriptor::ToString() const { const char* mode_string = mode_ == dnn::PoolingMode::kMaximum ? "kMaximum" : "kAverage"; - return port::Printf( - "{mode: %s window_height: %lld window_width: %lld vertical_stride: %lld " - "horizontal_stride: %lld vertical padding: %lld horizontal padding: " - "%lld}", - mode_string, window_height_, window_width_, vertical_stride_, - horizontal_stride_, vertical_padding_, horizontal_padding_); + + string window, strides, padding; + for (int i = 0; i < ndims_; i++) { + port::Appendf(&window, "%lld ", window_[i]); + port::Appendf(&strides, "%lld ", strides_[i]); + port::Appendf(&padding, "%lld", padding_[i]); + } + + return port::Printf("{mode: %s window: %s strides: %s padding: %s}", + mode_string, window.c_str(), strides.c_str(), + padding.c_str()); } string PoolingDescriptor::ToShortString() const { + string window, strides, padding; + for (int i = 0; i < ndims_; i++) { + port::Appendf(&window, "_w%d:%lld", i, window_[i]); + port::Appendf(&strides, "_s%d:%lld", i, strides_[i]); + port::Appendf(&padding, "_p%d:%lld", i, padding_[i]); + } return port::StrCat(mode_ == dnn::PoolingMode::kMaximum ? "max" : "avg", - "_y:", window_height_, "_x:", window_width_, "_py:", - vertical_padding_, "_px:", horizontal_padding_, "_sy:", - vertical_stride_, "_sx:", horizontal_stride_); + window, strides, padding); } // -- NormalizeDescriptor |