diff options
author | Jingyue Wu <jingyue@google.com> | 2017-06-12 21:42:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-12 21:45:49 -0700 |
commit | de702291b0d160ae8ef13448311b666d71c0d47a (patch) | |
tree | 0afdb08b15db69c1417cbfbac95c0263b03575b8 /tensorflow/stream_executor/stream.cc | |
parent | 4cc0caadf6b6a1b3b589595d6a752802f3d23356 (diff) |
TransformTensor supports NCHW_VECT_C layout and int8 data type.

o Add new DataType kInt8
o Add new DataLayout kBatchDepthYX4 and new FilterLayout kOutputInputYX4, both
  of which map to CUDNN_TENSOR_NCHW_VECT_C
o Change (Then|Do)TransformTensor to take input and output element types.
o Add new tensor format FORMAT_NCHW_VECT_C, and change the utility functions in
  tensor_format.h to work with the new format.

PiperOrigin-RevId: 158806412
Diffstat (limited to 'tensorflow/stream_executor/stream.cc')
-rw-r--r-- | tensorflow/stream_executor/stream.cc | 28 |
1 files changed, 22 insertions, 6 deletions
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index c27dda88cc..fff9accae3 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -166,6 +166,19 @@ string ToVlogString(dnn::DepthToSpaceLayout depth_to_space_layout) {
   return "unknown DepthToSpaceLayout";
 }

+string ToVlogString(dnn::DataType data_type) {
+  switch (data_type) {
+    case dnn::DataType::kFloat:
+      return "dnn::DataType::kFloat";
+    case dnn::DataType::kDouble:
+      return "dnn::DataType::kDouble";
+    case dnn::DataType::kHalf:
+      return "dnn::DataType::kHalf";
+    case dnn::DataType::kInt8:
+      return "dnn::DataType::kInt8";
+  }
+}
+
 // Used together with PARAM to VLOG calls made to the stream. Intended
 // to be used like this:
 //
@@ -4390,15 +4403,18 @@ Stream &Stream::ThenRnnBackward(
 }

 Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
-                                    const DeviceMemory<float> &input_data,
+                                    dnn::DataType input_type,
+                                    const DeviceMemoryBase &input_data,
                                     const dnn::BatchDescriptor &output_desc,
-                                    DeviceMemory<float> *output_data) {
-  VLOG_CALL(PARAM(input_desc), PARAM(input_data), PARAM(output_desc),
-            PARAM(output_data));
+                                    dnn::DataType output_type,
+                                    DeviceMemoryBase *output_data) {
+  VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data),
+            PARAM(output_desc), PARAM(output_type), PARAM(output_data));
   if (ok()) {
     if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
-      CheckError(dnn->DoTransformTensor(this, input_desc, input_data,
-                                        output_desc, output_data));
+      CheckError(dnn->DoTransformTensor(this, input_desc, input_type,
+                                        input_data, output_desc, output_type,
+                                        output_data));
     } else {
       SetErrorAndLogNoDnnSupport();
     }