aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream.cc
diff options
context:
space:
mode:
authorGravatar Jingyue Wu <jingyue@google.com>2017-06-12 21:42:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-12 21:45:49 -0700
commitde702291b0d160ae8ef13448311b666d71c0d47a (patch)
tree0afdb08b15db69c1417cbfbac95c0263b03575b8 /tensorflow/stream_executor/stream.cc
parent4cc0caadf6b6a1b3b589595d6a752802f3d23356 (diff)
TransformTensor supports NCHW_VECT_C layout and int8 data type.
o Add new DataType kInt8 o Add new DataLayout kBatchDepthYX4 and new FilterLayout kOutputInputYX4, both of which map to CUDNN_TENSOR_NCHW_VECT_C o Change (Then|Do)TransformTensor to take input and output element types. o Add new tensor format FORMAT_NCHW_VECT_C, and change the utility functions in tensor_format.h to work with the new format. PiperOrigin-RevId: 158806412
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r--tensorflow/stream_executor/stream.cc28
1 files changed, 22 insertions, 6 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index c27dda88cc..fff9accae3 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -166,6 +166,19 @@ string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
return "unknown DepthToSpaceLayout";
}
+string ToVlogString(dnn::DataType data_type) {
+ switch (data_type) {
+ case dnn::DataType::kFloat:
+ return "dnn::DataType::kFloat";
+ case dnn::DataType::kDouble:
+ return "dnn::DataType::kDouble";
+ case dnn::DataType::kHalf:
+ return "dnn::DataType::kHalf";
+ case dnn::DataType::kInt8:
+ return "dnn::DataType::kInt8";
+ }
+}
+
// Used together with PARAM to VLOG calls made to the stream. Intended
// to be used like this:
//
@@ -4390,15 +4403,18 @@ Stream &Stream::ThenRnnBackward(
}
Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
- const DeviceMemory<float> &input_data,
+ dnn::DataType input_type,
+ const DeviceMemoryBase &input_data,
const dnn::BatchDescriptor &output_desc,
- DeviceMemory<float> *output_data) {
- VLOG_CALL(PARAM(input_desc), PARAM(input_data), PARAM(output_desc),
- PARAM(output_data));
+ dnn::DataType output_type,
+ DeviceMemoryBase *output_data) {
+ VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data),
+ PARAM(output_desc), PARAM(output_type), PARAM(output_data));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
- CheckError(dnn->DoTransformTensor(this, input_desc, input_data,
- output_desc, output_data));
+ CheckError(dnn->DoTransformTensor(this, input_desc, input_type,
+ input_data, output_desc, output_type,
+ output_data));
} else {
SetErrorAndLogNoDnnSupport();
}