diff options
author | 2016-08-17 14:22:14 -0800 | |
---|---|---|
committer | 2016-08-17 15:34:07 -0700 | |
commit | 204eac513d875a4fb437d4985b2ad520f79c262c (patch) | |
tree | fadcebb02c4c78363449f28ea56fe67aa2bcd63b /tensorflow/core/kernels/conv_ops.h | |
parent | 699154cedccef7a8a44de4be2d11dc737513f941 (diff) |
Relax static shape information requirements for depthwise and separable conv2d.
The current implementation requires that if rank of a tensor is statically
known, certain dimensions needs to be known statically as well.
Change: 130570253
Diffstat (limited to 'tensorflow/core/kernels/conv_ops.h')
-rw-r--r-- | tensorflow/core/kernels/conv_ops.h | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h new file mode 100644 index 0000000000..d09db3dc15 --- /dev/null +++ b/tensorflow/core/kernels/conv_ops.h @@ -0,0 +1,58 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_KERNELS_CONV_OPS_H_ +#define TENSORFLOW_KERNELS_CONV_OPS_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/util/tensor_format.h" + +#if GOOGLE_CUDA +#include "tensorflow/core/kernels/conv_ops_gpu.h" +#include "tensorflow/core/platform/stream_executor.h" +#endif // GOOGLE_CUDA + +namespace tensorflow { + +// Forward declaration. +class OpKernelContext; + +template <typename Device, typename T> +class LaunchConv2DOp { + public: + void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& input, const Tensor& filter, int row_stride, + int col_stride, const Eigen::PaddingType& padding, Tensor* output, + TensorFormat data_format); +}; + +#ifdef GOOGLE_CUDA +template <typename T> +class LaunchConv2DOp<Eigen::GpuDevice, T> { + public: + void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, + const Tensor& input, const Tensor& filter, int row_stride, + int col_stride, const Eigen::PaddingType& padding, Tensor* output, + TensorFormat data_format); + + private: + AutoTuneMap<ConvParameters, perftools::gputools::dnn::AlgorithmConfig> + autotune_results_; +}; +#endif // GOOGLE_CUDA + +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_CONV_OPS_H |