author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-09-14 10:48:50 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-09-14 10:58:23 -0700
commit     52d7ed1a133cb1c3a2e13532bf97beef19c1516d (patch)
tree       6469c2c73872222058c20406a9ecdac1b4a8e38d /tensorflow/contrib/lite/kernels/internal
parent     d035a83459330c87bbc527e3d480b65f32841997 (diff)
Convert more kernel signatures to use runtime shapes.
PiperOrigin-RevId: 213007905
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal')
3 files changed, 328 insertions, 66 deletions
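
Every hunk below applies the same mechanical pattern: each activation kernel gains a new overload that takes an op-specific params struct (SoftmaxParams, LogisticParams, TanhParams) followed by (shape, data) pairs ordered input-then-output, while the old signature survives as a thin legacy wrapper that packs its loose arguments into the struct and forwards. A minimal sketch of what the change means for a caller, using float Softmax as the example; the wrapper function and buffer names here are hypothetical, while the tflite types and both Softmax signatures are the ones shown in the diff:

#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"

// Hypothetical caller showing the old and new calling conventions side by
// side; both compute the same result while the legacy overload remains.
void CallSoftmaxBothWays(const tflite::RuntimeShape& shape,
                         const float* input, float* output) {
  // Old style: loose positional arguments, each data pointer before its shape.
  tflite::optimized_ops::Softmax(input, shape, /*beta=*/1.0f, output, shape);

  // New style: op parameters gathered into a params struct, then
  // (shape, data) pairs for the input and the output.
  tflite::SoftmaxParams params;
  params.beta = 1.0f;
  tflite::optimized_ops::Softmax(params, shape, input, shape, output);
}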
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 659a65a8ea..464207d739 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -4431,9 +4431,9 @@ inline void LocalResponseNormalization(
   }
 }
 
-inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
-                    float beta, float* output_data,
-                    const RuntimeShape& output_shape) {
+inline void Softmax(const SoftmaxParams& params,
+                    const RuntimeShape& input_shape, const float* input_data,
+                    const RuntimeShape& output_shape, float* output_data) {
   gemmlowp::ScopedProfilingLabel label("Softmax");
   MatchingFlatSize(input_shape, output_shape);
 
@@ -4441,7 +4441,8 @@ inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
   auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
   // Compute the exponential first, removing the max coefficient for numerical
   // stability.
-  out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta;
+  out_mat =
+      (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * params.beta;
   // We are separating out the exp function so that exp can be vectorized.
   out_mat = out_mat.array().exp();
   // Normalize to get the activations.
@@ -4450,10 +4451,22 @@ inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
   out_mat.array().rowwise() *= scale;
 }
 
-inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
-                    int32 input_beta_multiplier, int32 input_beta_left_shift,
-                    int diff_min, uint8* output_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
+                    float beta, float* output_data,
                     const RuntimeShape& output_shape) {
+  SoftmaxParams params;
+  params.beta = beta;
+  Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Softmax(const SoftmaxParams& params,
+                    const RuntimeShape& input_shape, const uint8* input_data,
+                    const RuntimeShape& output_shape, uint8* output_data) {
+  const int32 input_beta_multiplier = params.input_multiplier;
+  const int32 input_beta_left_shift = params.input_left_shift;
+  const int diff_min = params.diff_min;
   // The representation chosen for the input to the exp() function is Q5.26.
   // We need to leave extra space since values that we skip might be as large as
   // -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -4659,10 +4672,24 @@ inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
+                    int32 input_beta_multiplier, int32 input_beta_left_shift,
+                    int diff_min, uint8* output_data,
+                    const RuntimeShape& output_shape) {
+  SoftmaxParams params;
+  params.input_multiplier = input_beta_multiplier;
+  params.input_left_shift = input_beta_left_shift;
+  params.diff_min = diff_min;
+  Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
 // TODO(myenik): This is the same as the reference implementation, not actually
 // optimized yet.
-inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
-                       float* output_data, const RuntimeShape& output_shape) {
+inline void LogSoftmax(const SoftmaxParams& params,
+                       const RuntimeShape& input_shape, const float* input_data,
+                       const RuntimeShape& output_shape, float* output_data) {
   gemmlowp::ScopedProfilingLabel label("LogSoftmax");
   const int trailing_dim = input_shape.DimensionsCount() - 1;
   const int outer_size =
@@ -4695,6 +4722,15 @@ inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+                       float* output_data, const RuntimeShape& output_shape) {
+  SoftmaxParams params;
+  // No params currently used for float LogSoftmax.
+  LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
 template <int OutputIntegerBits, int InputIntegerBits>
 inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
 log_x_for_x_greater_than_or_equal_to_1_impl(
@@ -4809,12 +4845,15 @@ log_x_for_x_greater_than_or_equal_to_1(
 }
 
 // Currently just a copy of the reference code.
-inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
-                       int32 input_multiplier, int32 input_left_shift,
-                       int32 reverse_scaling_divisor,
-                       int32 reverse_scaling_right_shift, int diff_min,
-                       uint8* output_data, const RuntimeShape& output_shape) {
+inline void LogSoftmax(const SoftmaxParams& params,
+                       const RuntimeShape& input_shape, const uint8* input_data,
+                       const RuntimeShape& output_shape, uint8* output_data) {
   gemmlowp::ScopedProfilingLabel label("LogSoftmax/Uint8");
+  const int32 input_multiplier = params.input_multiplier;
+  const int32 input_left_shift = params.input_left_shift;
+  const int32 reverse_scaling_divisor = params.reverse_scaling_divisor;
+  const int32 reverse_scaling_right_shift = params.reverse_scaling_right_shift;
+  const int diff_min = params.diff_min;
   // The representation chosen for the input to the exp() function is Q5.26.
   // We need to leave extra space since values that we skip might be as large as
   // -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -4896,7 +4935,24 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
   }
 }
 
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
+                       int32 input_multiplier, int32 input_left_shift,
+                       int32 reverse_scaling_divisor,
+                       int32 reverse_scaling_right_shift, int diff_min,
+                       uint8* output_data, const RuntimeShape& output_shape) {
+  SoftmaxParams params;
+  params.input_multiplier = input_multiplier;
+  params.input_left_shift = input_left_shift;
+  params.reverse_scaling_divisor = reverse_scaling_divisor;
+  params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+  params.diff_min = diff_min;
+  LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+                     const RuntimeShape& input_shape, const float* input_data,
                      const RuntimeShape& output_shape, float* output_data) {
   gemmlowp::ScopedProfilingLabel label("Logistic");
   auto input_map = MapAsVector(input_data, input_shape);
@@ -4905,11 +4961,23 @@ inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
       input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
 }
 
-inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
-                     int32 input_zero_point, int32 input_range_radius,
-                     int32 input_multiplier, int input_left_shift,
-                     uint8* output_data, const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
+                     const RuntimeShape& output_shape, float* output_data) {
+  LogisticParams params;
+  // No params currently needed by float Logistic.
+  Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+                     const RuntimeShape& input_shape, const uint8* input_data,
+                     const RuntimeShape& output_shape, uint8* output_data) {
   gemmlowp::ScopedProfilingLabel label("Logistic/Uint8");
+  const int32 input_zero_point = params.input_zero_point;
+  const int32 input_range_radius = params.input_range_radius;
+  const int32 input_multiplier = params.input_multiplier;
+  const int input_left_shift = params.input_left_shift;
   const int size = MatchingFlatSize(input_shape, output_shape);
 
   int c = 0;
@@ -5042,7 +5110,22 @@ inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
   }
 }
 
-inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
+                     int32 input_zero_point, int32 input_range_radius,
+                     int32 input_multiplier, int input_left_shift,
+                     uint8* output_data, const RuntimeShape& output_shape) {
+  LogisticParams params;
+  params.input_zero_point = input_zero_point;
+  params.input_range_radius = input_range_radius;
+  params.input_multiplier = input_multiplier;
+  params.input_left_shift = input_left_shift;
+  Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+                     const RuntimeShape& input_shape, const int16* input_data,
                      const RuntimeShape& output_shape, int16* output_data) {
   gemmlowp::ScopedProfilingLabel label("Logistic/Int16");
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -5102,26 +5185,51 @@ inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy version.
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+                     const RuntimeShape& output_shape, int16* output_data) {
+  LogisticParams params;
+  // No params currently needed by int16 Logistic.
+  Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
 // Legacy version.
 inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
                      int16* output_data, const RuntimeShape& output_shape) {
-  Logistic(input_shape, input_data, output_shape, output_data);
+  LogisticParams params;
+  // No params currently needed by int16 Logistic.
+  Logistic(params, input_shape, input_data, output_shape, output_data);
 }
 
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
-                 const RuntimeShape& output_shape, float* output_data) {
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+                 const float* input_data, const RuntimeShape& output_shape,
+                 float* output_data) {
   gemmlowp::ScopedProfilingLabel label("Tanh");
   auto input_map = MapAsVector(input_data, input_shape);
   auto output_map = MapAsVector(output_data, output_shape);
   output_map.array() = input_map.array().tanh();
 }
 
-inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
-                 int32 input_zero_point, int32 input_range_radius,
-                 int32 input_multiplier, int input_left_shift,
-                 uint8* output_data, const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+                 const RuntimeShape& output_shape, float* output_data) {
+  TanhParams params;
+  // Currently no params needed for float Tanh.
+  Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+                 const uint8* input_data, const RuntimeShape& output_shape,
+                 uint8* output_data) {
   // Note that this is almost the exact same code as in Logistic().
   gemmlowp::ScopedProfilingLabel label("Tanh");
+  const int32 input_zero_point = params.input_zero_point;
+  const int32 input_range_radius = params.input_range_radius;
+  const int32 input_multiplier = params.input_multiplier;
+  const int input_left_shift = params.input_left_shift;
   const int size = MatchingFlatSize(input_shape, output_shape);
 
   int c = 0;
@@ -5263,10 +5371,25 @@ inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
   }
 }
 
-inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
-                 int input_left_shift, int16* output_data,
-                 const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
+                 int32 input_zero_point, int32 input_range_radius,
+                 int32 input_multiplier, int input_left_shift,
+                 uint8* output_data, const RuntimeShape& output_shape) {
+  TanhParams params;
+  params.input_zero_point = input_zero_point;
+  params.input_range_radius = input_range_radius;
+  params.input_multiplier = input_multiplier;
+  params.input_left_shift = input_left_shift;
+  Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+                 const int16* input_data, const RuntimeShape& output_shape,
+                 int16* output_data) {
   gemmlowp::ScopedProfilingLabel label("Tanh/Int16");
+  const int input_left_shift = params.input_left_shift;
   // Support for shifts is limited until we have a parameterized version of
   // SaturatingRoundingMultiplyByPOT().
   TFLITE_DCHECK_GE(input_left_shift, 0);
@@ -5363,6 +5486,16 @@ inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
+                 int input_left_shift, int16* output_data,
+                 const RuntimeShape& output_shape) {
+  TanhParams params;
+  params.input_left_shift = input_left_shift;
+  Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
 template <typename SrcT, typename DstT>
 inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
                  const RuntimeShape& output_shape, DstT* output_data) {
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 66f18ec195..111adbf5b3 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -2902,9 +2902,9 @@ inline void LocalResponseNormalization(
   }
 }
 
-inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
-                    float beta, float* output_data,
-                    const RuntimeShape& output_shape) {
+inline void Softmax(const SoftmaxParams& params,
+                    const RuntimeShape& input_shape, const float* input_data,
+                    const RuntimeShape& output_shape, float* output_data) {
   const int trailing_dim = input_shape.DimensionsCount() - 1;
   const int outer_size =
       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
@@ -2923,21 +2923,33 @@ inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
     // Compute sum.
     float sum = 0.f;
     for (int c = 0; c < depth; ++c) {
-      sum += std::exp((input_data[i * depth + c] - max) * beta);
+      sum += std::exp((input_data[i * depth + c] - max) * params.beta);
     }
 
     // Compute result.
     for (int c = 0; c < depth; ++c) {
       output_data[i * depth + c] =
-          std::exp((input_data[i * depth + c] - max) * beta) / sum;
+          std::exp((input_data[i * depth + c] - max) * params.beta) / sum;
     }
   }
 }
 
-inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
-                    int32 input_beta_multiplier, int32 input_beta_left_shift,
-                    int diff_min, uint8* output_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
+                    float beta, float* output_data,
                     const RuntimeShape& output_shape) {
+  SoftmaxParams params;
+  params.beta = beta;
+  Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Softmax(const SoftmaxParams& params,
+                    const RuntimeShape& input_shape, const uint8* input_data,
+                    const RuntimeShape& output_shape, uint8* output_data) {
+  const int32 input_beta_multiplier = params.input_multiplier;
+  const int32 input_beta_left_shift = params.input_left_shift;
+  const int diff_min = params.diff_min;
   // The representation chosen for the input to the exp() function is Q5.26.
   // We need to leave extra space since values that we skip might be as large as
   // -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -3015,8 +3027,22 @@ inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
   }
 }
 
-inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
-                       float* output_data, const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
+                    int32 input_beta_multiplier, int32 input_beta_left_shift,
+                    int diff_min, uint8* output_data,
+                    const RuntimeShape& output_shape) {
+  SoftmaxParams params;
+  params.input_multiplier = input_beta_multiplier;
+  params.input_left_shift = input_beta_left_shift;
+  params.diff_min = diff_min;
+  Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void LogSoftmax(const SoftmaxParams& params,
+                       const RuntimeShape& input_shape, const float* input_data,
+                       const RuntimeShape& output_shape, float* output_data) {
   const int trailing_dim = input_shape.DimensionsCount() - 1;
   const int outer_size =
       MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
@@ -3046,6 +3072,15 @@ inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+                       float* output_data, const RuntimeShape& output_shape) {
+  SoftmaxParams params;
+  // No params currently used for float LogSoftmax.
+  LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
 // Although currently the name of this function says that it cannot handle
 // values less than 1, in practice it can handle as low as 1/x_max, where
 // x_max is the largest representable input. In other words, the output range
@@ -3161,16 +3196,19 @@ log_x_for_x_greater_than_or_equal_to_1(
       input_val);
 }
 
-inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
-                       int32 input_multiplier, int32 input_left_shift,
-                       int32 reverse_scaling_divisor,
-                       int32 reverse_scaling_right_shift, int diff_min,
-                       uint8* output_data, const RuntimeShape& output_shape) {
+inline void LogSoftmax(const SoftmaxParams& params,
+                       const RuntimeShape& input_shape, const uint8* input_data,
+                       const RuntimeShape& output_shape, uint8* output_data) {
+  const int32 input_multiplier = params.input_multiplier;
+  const int32 input_left_shift = params.input_left_shift;
+  const int32 reverse_scaling_divisor = params.reverse_scaling_divisor;
+  const int32 reverse_scaling_right_shift = params.reverse_scaling_right_shift;
+  const int diff_min = params.diff_min;
   // The representation chosen for the input to the exp() function is Q5.26.
-  // We need to leave extra space since values that we skip might be as large as
-  // -32 before multiplying by input_beta_multiplier, and therefore as large as
-  // -16 afterwards. Note that exp(-8) is definitely not insignificant to
-  // accumulation, but exp(-16) definitely is.
+  // We need to leave extra space since values that we skip might be as large
+  // as -32 before multiplying by input_beta_multiplier, and therefore as
+  // large as -16 afterwards. Note that exp(-8) is definitely not
+  // insignificant to accumulation, but exp(-16) definitely is.
   static constexpr int kScaledDiffIntegerBits = 5;
   static constexpr int kAccumulationIntegerBits = 12;
   static constexpr int kOutputIntegerBits = 4;
@@ -3247,7 +3285,24 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
   }
 }
 
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
+                       int32 input_multiplier, int32 input_left_shift,
+                       int32 reverse_scaling_divisor,
+                       int32 reverse_scaling_right_shift, int diff_min,
+                       uint8* output_data, const RuntimeShape& output_shape) {
+  SoftmaxParams params;
+  params.input_multiplier = input_multiplier;
+  params.input_left_shift = input_left_shift;
+  params.reverse_scaling_divisor = reverse_scaling_divisor;
+  params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+  params.diff_min = diff_min;
+  LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+                     const RuntimeShape& input_shape, const float* input_data,
                      const RuntimeShape& output_shape, float* output_data) {
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
 
@@ -3258,10 +3313,22 @@ inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
   }
 }
 
-inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
-                     int32 input_zero_point, int32 input_range_radius,
-                     int32 input_multiplier, int input_left_shift,
-                     uint8* output_data, const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
+                     const RuntimeShape& output_shape, float* output_data) {
+  LogisticParams params;
+  // No params currently needed by float Logistic.
+  Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+                     const RuntimeShape& input_shape, const uint8* input_data,
+                     const RuntimeShape& output_shape, uint8* output_data) {
+  const int32 input_zero_point = params.input_zero_point;
+  const int32 input_range_radius = params.input_range_radius;
+  const int32 input_multiplier = params.input_multiplier;
+  const int input_left_shift = params.input_left_shift;
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
@@ -3296,7 +3363,22 @@ inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
   }
 }
 
-inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
+                     int32 input_zero_point, int32 input_range_radius,
+                     int32 input_multiplier, int input_left_shift,
+                     uint8* output_data, const RuntimeShape& output_shape) {
+  LogisticParams params;
+  params.input_zero_point = input_zero_point;
+  params.input_range_radius = input_range_radius;
+  params.input_multiplier = input_multiplier;
+  params.input_left_shift = input_left_shift;
+  Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+                     const RuntimeShape& input_shape, const int16* input_data,
                      const RuntimeShape& output_shape, int16* output_data) {
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
 
@@ -3314,8 +3396,18 @@ inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
   }
 }
 
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
-                 const RuntimeShape& output_shape, float* output_data) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+                     const RuntimeShape& output_shape, int16* output_data) {
+  LogisticParams params;
+  // No params currently needed by int16 Logistic.
+  Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+                 const float* input_data, const RuntimeShape& output_shape,
+                 float* output_data) {
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
 
   for (int i = 0; i < flat_size; i++) {
@@ -3325,10 +3417,22 @@ inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
   }
 }
 
-inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
-                 int32 input_zero_point, int32 input_range_radius,
-                 int32 input_multiplier, int input_left_shift,
-                 uint8* output_data, const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+                 const RuntimeShape& output_shape, float* output_data) {
+  TanhParams params;
+  // Currently no params needed for float Tanh.
+  Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+                 const uint8* input_data, const RuntimeShape& output_shape,
+                 uint8* output_data) {
+  const int32 input_zero_point = params.input_zero_point;
+  const int32 input_range_radius = params.input_range_radius;
+  const int32 input_multiplier = params.input_multiplier;
+  const int input_left_shift = params.input_left_shift;
   const int32 output_zero_point = 128;
   const int flat_size = MatchingFlatSize(input_shape, output_shape);
 
@@ -3365,9 +3469,24 @@ inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
   }
 }
 
-inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
-                 int input_left_shift, int16* output_data,
-                 const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
+                 int32 input_zero_point, int32 input_range_radius,
+                 int32 input_multiplier, int input_left_shift,
+                 uint8* output_data, const RuntimeShape& output_shape) {
+  TanhParams params;
+  params.input_zero_point = input_zero_point;
+  params.input_range_radius = input_range_radius;
+  params.input_multiplier = input_multiplier;
+  params.input_left_shift = input_left_shift;
+  Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+                 const int16* input_data, const RuntimeShape& output_shape,
+                 int16* output_data) {
+  const int input_left_shift = params.input_left_shift;
   // Support for shifts is limited until we have a parameterized version of
   // SaturatingRoundingMultiplyByPOT().
   TFLITE_DCHECK_GE(input_left_shift, 0);
@@ -3398,6 +3517,16 @@ inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
   }
 }
 
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
+                 int input_left_shift, int16* output_data,
+                 const RuntimeShape& output_shape) {
+  TanhParams params;
+  params.input_left_shift = input_left_shift;
+  Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
 inline void Dequantize(const tflite::DequantizationParams& op_params,
                        const RuntimeShape& input_shape, const uint8* input_data,
                        const RuntimeShape& output_shape, float* output_data) {
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 023707d466..87e8ff0346 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -885,8 +885,8 @@ struct SoftmaxParams {
   // for LogSoftmax.
   double beta;
   // uint8 inference params. Used even when beta defaults to 1.0.
-  int32 input_beta_multiplier;
-  int32 input_beta_left_shift;
+  int32 input_multiplier;
+  int32 input_left_shift;
   // Reverse scaling is only used by LogSoftmax.
   int32 reverse_scaling_divisor;
   int32 reverse_scaling_right_shift;
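
Note that the types.h hunk renames the uint8 fields of SoftmaxParams from input_beta_multiplier/input_beta_left_shift to the generic input_multiplier/input_left_shift, which is why the new kernels read params.input_multiplier while the legacy wrappers still accept the beta-named arguments. A sketch of a caller populating the renamed struct for the quantized path; the concrete multiplier, shift, and diff_min values here are hypothetical placeholders, and a real kernel would derive them from the input scale and beta (e.g. via the helpers in quantization_util.h) at Prepare() time:

#include <cstdint>

#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"

// Hypothetical illustration of the renamed SoftmaxParams fields feeding the
// new params-struct uint8 Softmax overload.
void RunQuantizedSoftmax(const tflite::RuntimeShape& shape,
                         const uint8_t* input, uint8_t* output) {
  tflite::SoftmaxParams params;
  params.input_multiplier = 1073741824;  // was input_beta_multiplier
  params.input_left_shift = 2;           // was input_beta_left_shift
  params.diff_min = -128;
  tflite::reference_ops::Softmax(params, shape, input, shape, output);
}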