aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream.h
diff options
context:
space:
mode:
authorGravatar Jingyue Wu <jingyue@google.com>2017-06-12 21:42:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-12 21:45:49 -0700
commitde702291b0d160ae8ef13448311b666d71c0d47a (patch)
tree0afdb08b15db69c1417cbfbac95c0263b03575b8 /tensorflow/stream_executor/stream.h
parent4cc0caadf6b6a1b3b589595d6a752802f3d23356 (diff)
TransformTensor supports NCHW_VECT_C layout and int8 data type.
o Add new DataType kInt8 o Add new DataLayout kBatchDepthYX4 and new FilterLayout kOutputInputYX4, both of which map to CUDNN_TENSOR_NCHW_VECT_C o Change (Then|Do)TransformTensor to take input and output element types. o Add new tensor format FORMAT_NCHW_VECT_C, and change the utility functions in tensor_format.h to work with the new format. PiperOrigin-RevId: 158806412
Diffstat (limited to 'tensorflow/stream_executor/stream.h')
-rw-r--r--tensorflow/stream_executor/stream.h18
1 files changed, 16 insertions, 2 deletions
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index b919c0095d..b07d3021c9 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1652,9 +1652,23 @@ class Stream {
// Enqueue onto the stream a operation that transforms a tensor.
// See DnnSupport::DoTransformTensor for more details.
Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
- const DeviceMemory<float> &input_data,
+ dnn::DataType input_type,
+ const DeviceMemoryBase &input_data,
const dnn::BatchDescriptor &output_desc,
- DeviceMemory<float> *output_data);
+ dnn::DataType output_type,
+ DeviceMemoryBase *output_data);
+
+ // The templated version of the above ThenTransformTensor. Useful when the
+ // input and output types are statically known.
+ template <typename InElemT, typename OutElemT>
+ Stream &ThenTransformTensor(const dnn::BatchDescriptor &input_desc,
+ const DeviceMemory<InElemT> &input_data,
+ const dnn::BatchDescriptor &output_desc,
+ DeviceMemory<OutElemT> *output_data) {
+ return ThenTransformTensor(input_desc, dnn::ToDataType<InElemT>(),
+ input_data, output_desc,
+ dnn::ToDataType<OutElemT>(), output_data);
+ }
// (Synchronously) block the host code waiting for the operations
// entrained on the stream (enqueued to this point in program