path: root/tensorflow/stream_executor/dnn.cc
author	Jingyue Wu <jingyue@google.com>	2017-06-12 21:42:16 -0700
committer	TensorFlower Gardener <gardener@tensorflow.org>	2017-06-12 21:45:49 -0700
commit	de702291b0d160ae8ef13448311b666d71c0d47a (patch)
tree	0afdb08b15db69c1417cbfbac95c0263b03575b8 /tensorflow/stream_executor/dnn.cc
parent	4cc0caadf6b6a1b3b589595d6a752802f3d23356 (diff)
TransformTensor supports NCHW_VECT_C layout and int8 data type.
o Add new DataType kInt8.
o Add new DataLayout kBatchDepthYX4 and new FilterLayout kOutputInputYX4, both of which map to CUDNN_TENSOR_NCHW_VECT_C.
o Change (Then|Do)TransformTensor to take input and output element types.
o Add new tensor format FORMAT_NCHW_VECT_C, and change the utility functions in tensor_format.h to work with the new format.

PiperOrigin-RevId: 158806412
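
For illustration only (not part of this change): in the NCHW_VECT_C layout that kBatchDepthYX4 and kOutputInputYX4 map to, the channel dimension is packed in groups of 4 int8 values stored contiguously at the innermost position. A minimal sketch of the resulting addressing, with all identifiers hypothetical:

    #include <cstdint>

    // NCHW_VECT_C addressing sketch: the logical shape [N, C, H, W] is stored
    // as [N, C/4, H, W, 4] int8, so 4 consecutive channels share one vector slot.
    // Assumes C is a multiple of 4.
    inline int64_t VectCOffset(int64_t n, int64_t c, int64_t h, int64_t w,
                               int64_t C, int64_t H, int64_t W) {
      const int64_t c_outer = c / 4;  // which 4-channel block
      const int64_t c_inner = c % 4;  // position within the block
      return (((n * (C / 4) + c_outer) * H + h) * W + w) * 4 + c_inner;
    }
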
Diffstat (limited to 'tensorflow/stream_executor/dnn.cc')
-rw-r--r--	tensorflow/stream_executor/dnn.cc	17
1 file changed, 17 insertions, 0 deletions
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index ee78066f95..311f45f748 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -94,6 +94,8 @@ string DataLayoutString(DataLayout layout) {
return "BatchYXDepth";
case DataLayout::kBatchDepthYX:
return "BatchDepthYX";
+ case DataLayout::kBatchDepthYX4:
+ return "BatchDepthYX4";
default:
LOG(FATAL) << "Unknown data layout " << static_cast<int32>(layout);
}
@@ -104,6 +106,8 @@ string FilterLayoutString(FilterLayout layout) {
switch (layout) {
case FilterLayout::kOutputInputYX:
return "OutputInputYX";
+ case FilterLayout::kOutputInputYX4:
+ return "OutputInputYX4";
case FilterLayout::kInputYXOutput:
return "InputYXOutput";
case FilterLayout::kYXInputOutput:
@@ -161,6 +165,7 @@ std::tuple<int, int, int> GetDimIndices(const DataLayout& layout,
break;
case DataLayout::kBatchDepthYX:
+ case DataLayout::kBatchDepthYX4:
depth_idx = 1;
batch_idx = 0;
spatial_idx = 2;
@@ -224,6 +229,14 @@ std::vector<int64> BatchDescriptor::full_dims(const DataLayout& layout) const {
std::vector<int64> BatchDescriptor::full_strides(
const DataLayout& layout) const {
+ if (layout_ == DataLayout::kBatchDepthYX4) {
+ LOG(FATAL)
+ << "Cannot compute full strides for batch descriptor " << ToString()
+ << ", because its layout is kBatchDepthYX4. In fact, "
+ "cudnnSetTensorNdDescriptor doesn't work for kBatchDepthYX4 at all. "
+ "Use cudnnSetTensor4DDescriptor to set cudnnTensorDescriptor_t "
+ "instead.";
+ }
std::vector<int64> phys_dims = full_dims(layout_);
std::vector<int64> phys_strides(phys_dims.size());
phys_strides[ndims_ + 1] = 1;
@@ -285,6 +298,8 @@ string BatchDescriptor::ToShortString() const {
return port::StrCat(batch, spatial, depth, suffix);
case DataLayout::kBatchDepthYX:
return port::StrCat(batch, depth, spatial, suffix);
+ case DataLayout::kBatchDepthYX4:
+ return port::StrCat(batch, depth, spatial, suffix, "(VECT_C)");
default:
LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
return ""; // Avoid return warning (unreachable)
@@ -380,6 +395,8 @@ string FilterDescriptor::ToShortString() const {
switch (layout_) {
case FilterLayout::kOutputInputYX:
return port::StrCat(od, id, spatial);
+ case FilterLayout::kOutputInputYX4:
+ return port::StrCat(od, id, spatial, "(VECT_C)");
case FilterLayout::kInputYXOutput:
return port::StrCat(id, spatial, od);
case FilterLayout::kYXInputOutput:
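
As context for the new LOG(FATAL) in full_strides above: the commit treats kBatchDepthYX4 as a layout that cannot be expressed with per-dimension strides, and the message points to cuDNN's 4D descriptor setter instead. A hedged sketch of describing such a tensor, assuming cuDNN 6+ (not code from this commit; the helper name is hypothetical):

    #include <cudnn.h>

    // Sketch: describe an int8 NCHW_VECT_C tensor to cuDNN. Error handling is
    // elided for brevity.
    cudnnTensorDescriptor_t MakeVectCDescriptor(int n, int c, int h, int w) {
      cudnnTensorDescriptor_t desc;
      cudnnCreateTensorDescriptor(&desc);
      // The vectorized layout has no explicit stride form, so the 4D setter is
      // used rather than cudnnSetTensorNdDescriptor; c must be a multiple of 4.
      cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW_VECT_C,
                                 CUDNN_DATA_INT8x4, n, c, h, w);
      return desc;
    }
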