author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-09-19 12:26:37 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-09-19 12:29:52 -0700
commit     c8e17b08263f3ba61a5fdf785e231e1c9f4029ca (patch)
tree       e20f4fd1704ae3f9a8cf94c6040274b8e301d108 /tensorflow/contrib/lite/kernels
parent     238424ffcf04c38561ed48ebadb16b3b3a770e2e (diff)
Convert more kernel signatures to use runtime shapes.
PiperOrigin-RevId: 213673402
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h | 60
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 74
2 files changed, 90 insertions(+), 44 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 6a7e664e85..1a2d45166a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -3804,11 +3804,11 @@ inline void LstmCell(
     uint8* concat_temp_data_uint8,
     const RuntimeShape& unextended_activ_temp_shape,
     int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
+  gemmlowp::ScopedProfilingLabel label(
+      "LstmCell/quantized (8bit external, 16bit internal)");
   int32 weights_zero_point = params.weights_zero_point;
   int32 accum_multiplier = params.accum_multiplier;
   int accum_shift = params.accum_shift;
-  gemmlowp::ScopedProfilingLabel label(
-      "LstmCell/quantized (8bit external, 16bit internal)");
   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
   TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
@@ -5063,8 +5063,7 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
   LogSoftmax(params, input_shape, input_data, output_shape, output_data);
 }
 
-inline void Logistic(const LogisticParams& params,
-                     const RuntimeShape& input_shape, const float* input_data,
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
                      const RuntimeShape& output_shape, float* output_data) {
   gemmlowp::ScopedProfilingLabel label("Logistic");
   auto input_map = MapAsVector(input_data, input_shape);
@@ -5073,13 +5072,13 @@ inline void Logistic(const LogisticParams& params,
       input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
 }
 
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
-                     const RuntimeShape& output_shape, float* output_data) {
-  LogisticParams params;
-  // No params currently needed by float Logistic.
-  Logistic(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
+                     const float* input_data, const RuntimeShape& output_shape,
+                     float* output_data) {
+  // Drop params: not needed.
+  Logistic(input_shape, input_data, output_shape, output_data);
 }
 
 inline void Logistic(const LogisticParams& params,
@@ -5315,22 +5314,21 @@ inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
   Logistic(params, input_shape, input_data, output_shape, output_data);
 }
 
-inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
-                 const float* input_data, const RuntimeShape& output_shape,
-                 float* output_data) {
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+                 const RuntimeShape& output_shape, float* output_data) {
   gemmlowp::ScopedProfilingLabel label("Tanh");
   auto input_map = MapAsVector(input_data, input_shape);
   auto output_map = MapAsVector(output_data, output_shape);
   output_map.array() = input_map.array().tanh();
 }
 
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
-                 const RuntimeShape& output_shape, float* output_data) {
-  TanhParams params;
-  // Currently no params needed for float Tanh.
-  Tanh(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+                 const float* input_data, const RuntimeShape& output_shape,
+                 float* output_data) {
+  // Drop params: not needed.
+  Tanh(input_shape, input_data, output_shape, output_data);
 }
 
 inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
@@ -6385,6 +6383,16 @@ void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
   output_map.array() = input1_map.array().min(min_value);
 }
 
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+                    const RuntimeShape&, const T* input2_data,
+                    const RuntimeShape& output_shape, T* output_data) {
+  // Drop shape of second input: not needed.
+  Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
 template <typename T>
 void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
              const T* input2_data, const RuntimeShape& output_shape,
@@ -6396,6 +6404,16 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
   output_map.array() = input1_map.array().max(max_value);
 }
 
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+                    const RuntimeShape&, const T* input2_data,
+                    const RuntimeShape& output_shape, T* output_data) {
+  // Drop shape of second input: not needed.
+  Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
 template <typename T>
 void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
                      const RuntimeShape& input_shape, const T* input_data,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 76fa1944bc..bb1d30b216 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -1916,7 +1916,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
                                const float* input2_data,
                                const RuntimeShape& output_shape,
                                float* output_data) {
-  gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/float");
+  gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/float");
   NdArrayDesc<4> desc1;
   NdArrayDesc<4> desc2;
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -1957,7 +1957,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
                                const uint8* input2_data,
                                const RuntimeShape& output_shape,
                                uint8* output_data) {
-  gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/uint8");
+  gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/uint8");
   NdArrayDesc<4> desc1;
   NdArrayDesc<4> desc2;
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -2021,7 +2021,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
                                const int32* input2_data,
                                const RuntimeShape& output_shape,
                                int32* output_data) {
-  gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int32");
+  gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/int32");
   NdArrayDesc<4> desc1;
   NdArrayDesc<4> desc2;
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -2061,7 +2061,7 @@ void BroadcastSub4DSlow(const ArithmeticParams& params,
                         const RuntimeShape& input1_shape, const T* input1_data,
                         const RuntimeShape& input2_shape, const T* input2_data,
                         const RuntimeShape& output_shape, T* output_data) {
-  gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/templated");
+  gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/templated");
   NdArrayDesc<4> desc1;
   NdArrayDesc<4> desc2;
   NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -3637,8 +3637,7 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
   LogSoftmax(params, input_shape, input_data, output_shape, output_data);
 }
 
-inline void Logistic(const LogisticParams& params,
-                     const RuntimeShape& input_shape, const float* input_data,
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
                      const RuntimeShape& output_shape, float* output_data) {
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
 
@@ -3649,13 +3648,13 @@ inline void Logistic(const LogisticParams& params,
   }
 }
 
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
-                     const RuntimeShape& output_shape, float* output_data) {
-  LogisticParams params;
-  // No params currently needed by float Logistic.
-  Logistic(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
+                     const float* input_data, const RuntimeShape& output_shape,
+                     float* output_data) {
+  // Drop params: not needed.
+  Logistic(input_shape, input_data, output_shape, output_data);
 }
 
 inline void Logistic(const LogisticParams& params,
@@ -3741,9 +3740,8 @@ inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
   Logistic(params, input_shape, input_data, output_shape, output_data);
 }
 
-inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
-                 const float* input_data, const RuntimeShape& output_shape,
-                 float* output_data) {
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+                 const RuntimeShape& output_shape, float* output_data) {
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
@@ -3753,13 +3751,13 @@ inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
   }
 }
 
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
-                 const RuntimeShape& output_shape, float* output_data) {
-  TanhParams params;
-  // Currently no params needed for float Tanh.
-  Tanh(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+                 const float* input_data, const RuntimeShape& output_shape,
+                 float* output_data) {
+  // Drop params: not needed.
+  Tanh(input_shape, input_data, output_shape, output_data);
 }
 
 inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
@@ -4735,6 +4733,16 @@ void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
   }
 }
 
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+                    const RuntimeShape&, const T* input2_data,
+                    const RuntimeShape& output_shape, T* output_data) {
+  // Drop shape of second input: not needed.
+  Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
 template <typename T>
 void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
              const T* input2_data, const RuntimeShape& output_shape,
@@ -4747,6 +4755,16 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
   }
 }
 
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+                    const RuntimeShape&, const T* input2_data,
+                    const RuntimeShape& output_shape, T* output_data) {
+  // Drop shape of second input: not needed.
+  Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
 template <typename T, typename Op>
 void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
                                    const T* input1_data,
@@ -4822,6 +4840,16 @@ void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
             std::greater<T1>());
 }
 
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T1, typename T2, typename T3>
+inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
+                   const RuntimeShape& input2_shape, const T3* input2_data,
+                   const RuntimeShape& output_shape, T2* output_data) {
+  // Drop shape of second input: not needed.
+  ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
 template <typename T>
 void Transpose(const TransposeParams& params,
                const RuntimeShape& unextended_input_shape, const T* input_data,
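
Illustration of what the params-accepting convenience overloads buy: a code generator (or a templated wrapper) can pass op params unconditionally for every data type, and the float kernels simply drop them. The following is a minimal, self-contained sketch under simplified assumptions; RuntimeShape, LogisticParams, and RunLogistic here are hypothetical stand-ins, not the real TFLite types.

// Sketch only: simplified stand-ins, not the TFLite headers.
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

struct RuntimeShape {   // stand-in: just a flat element count
  int flat_size;
};
struct LogisticParams {  // stand-in: quantization params (unused by float)
  std::int32_t input_zero_point = 0;
};

// Float kernel: no params needed.
inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
                     const RuntimeShape&, float* output_data) {
  for (int i = 0; i < input_shape.flat_size; ++i)
    output_data[i] = 1.f / (1.f + std::exp(-input_data[i]));
}

// Convenience overload: same call shape as the quantized kernels,
// params dropped, mirroring the pattern in this commit.
inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
                     const float* input_data, const RuntimeShape& output_shape,
                     float* output_data) {
  Logistic(input_shape, input_data, output_shape, output_data);
}

// One templated call site now serves every data type: it always passes
// params, and overload resolution picks the right kernel.
template <typename T>
void RunLogistic(const LogisticParams& params, const RuntimeShape& shape,
                 const T* in, T* out) {
  Logistic(params, shape, in, shape, out);
}

int main() {
  std::vector<float> in = {-1.f, 0.f, 1.f}, out(3);
  RunLogistic(LogisticParams{}, RuntimeShape{3}, in.data(), out.data());
  for (float v : out) std::cout << v << ' ';  // ~0.269 0.5 0.731
  std::cout << '\n';
}

Leaving the dropped parameter anonymous (as the diff does) also avoids unused-parameter warnings while keeping the fast path untouched.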
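The Minimum/Maximum/ArgMax additions follow the same idea in the other direction: the new overloads accept and ignore the second input's shape, so generated code can emit the uniform (shape1, data1, shape2, data2, out_shape, out) call used for other binary ops. A second self-contained sketch, again with a hypothetical simplified RuntimeShape rather than the real type:

// Sketch only: simplified stand-in types.
#include <algorithm>
#include <iostream>

struct RuntimeShape { int flat_size; };

// Existing-style kernel: the second operand is a scalar, so its shape
// never mattered and was not part of the signature.
template <typename T>
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
             const T* input2_data, const RuntimeShape&, T* output_data) {
  const T min_value = input2_data[0];
  for (int i = 0; i < input1_shape.flat_size; ++i)
    output_data[i] = std::min(input1_data[i], min_value);
}

// Convenience overload with the uniform binary-op call shape; the shape
// of the second input is dropped, as in the commit.
template <typename T>
inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
                    const RuntimeShape& /*input2_shape*/, const T* input2_data,
                    const RuntimeShape& output_shape, T* output_data) {
  Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
}

int main() {
  float a[] = {1.f, 5.f, 3.f}, cap = 2.f, out[3];
  // Uniform six-argument call, same as any other binary op would get.
  Minimum(RuntimeShape{3}, a, RuntimeShape{1}, &cap, RuntimeShape{3}, out);
  for (float v : out) std::cout << v << ' ';  // prints: 1 2 2
  std::cout << '\n';
}

The forwarding overloads are inline one-liners, so the uniform call shape costs nothing at runtime.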