diff options
author | Jingyue Wu <jingyue@google.com> | 2017-06-12 21:42:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-12 21:45:49 -0700 |
commit | de702291b0d160ae8ef13448311b666d71c0d47a (patch) | |
tree | 0afdb08b15db69c1417cbfbac95c0263b03575b8 /tensorflow/stream_executor/dnn.cc | |
parent | 4cc0caadf6b6a1b3b589595d6a752802f3d23356 (diff) |
TransformTensor supports NCHW_VECT_C layout and int8 data type.

- Add new DataType kInt8
- Add new DataLayout kBatchDepthYX4 and new FilterLayout kOutputInputYX4, both
  of which map to CUDNN_TENSOR_NCHW_VECT_C
- Change (Then|Do)TransformTensor to take input and output element types.
- Add new tensor format FORMAT_NCHW_VECT_C, and change the utility functions in
  tensor_format.h to work with the new format.

PiperOrigin-RevId: 158806412
Diffstat (limited to 'tensorflow/stream_executor/dnn.cc')
-rw-r--r-- | tensorflow/stream_executor/dnn.cc | 17 |
1 file changed, 17 insertions, 0 deletions
```diff
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index ee78066f95..311f45f748 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -94,6 +94,8 @@ string DataLayoutString(DataLayout layout) {
       return "BatchYXDepth";
     case DataLayout::kBatchDepthYX:
       return "BatchDepthYX";
+    case DataLayout::kBatchDepthYX4:
+      return "BatchDepthYX4";
     default:
       LOG(FATAL) << "Unknown data layout " << static_cast<int32>(layout);
   }
@@ -104,6 +106,8 @@ string FilterLayoutString(FilterLayout layout) {
   switch (layout) {
     case FilterLayout::kOutputInputYX:
       return "OutputInputYX";
+    case FilterLayout::kOutputInputYX4:
+      return "OutputInputYX4";
     case FilterLayout::kInputYXOutput:
       return "InputYXOutput";
     case FilterLayout::kYXInputOutput:
@@ -161,6 +165,7 @@ std::tuple<int, int, int> GetDimIndices(const DataLayout& layout,
       break;
 
     case DataLayout::kBatchDepthYX:
+    case DataLayout::kBatchDepthYX4:
       depth_idx = 1;
       batch_idx = 0;
       spatial_idx = 2;
@@ -224,6 +229,14 @@ std::vector<int64> BatchDescriptor::full_dims(const DataLayout& layout) const {
 
 std::vector<int64> BatchDescriptor::full_strides(
     const DataLayout& layout) const {
+  if (layout_ == DataLayout::kBatchDepthYX4) {
+    LOG(FATAL)
+        << "Cannot compute full strides for batch descriptor " << ToString()
+        << ", because its layout is kBatchDepthYX4. In fact, "
+           "cudnnSetTensorNdDescriptor doesn't work for kBatchDepthYX4 at all. "
+           "Use cudnnSetTensor4DDescriptor to set cudnnTensorDescriptor_t "
+           "instead.";
+  }
   std::vector<int64> phys_dims = full_dims(layout_);
   std::vector<int64> phys_strides(phys_dims.size());
   phys_strides[ndims_ + 1] = 1;
@@ -285,6 +298,8 @@ string BatchDescriptor::ToShortString() const {
       return port::StrCat(batch, spatial, depth, suffix);
     case DataLayout::kBatchDepthYX:
       return port::StrCat(batch, depth, spatial, suffix);
+    case DataLayout::kBatchDepthYX4:
+      return port::StrCat(batch, depth, spatial, suffix, "(VECT_C)");
     default:
       LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
       return "";  // Avoid return warning (unreachable)
@@ -380,6 +395,8 @@ string FilterDescriptor::ToShortString() const {
   switch (layout_) {
     case FilterLayout::kOutputInputYX:
       return port::StrCat(od, id, spatial);
+    case FilterLayout::kOutputInputYX4:
+      return port::StrCat(od, id, spatial, "(VECT_C)");
     case FilterLayout::kInputYXOutput:
       return port::StrCat(id, spatial, od);
     case FilterLayout::kYXInputOutput:
```