aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/internal
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-14 10:48:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-14 10:58:23 -0700
commit52d7ed1a133cb1c3a2e13532bf97beef19c1516d (patch)
tree6469c2c73872222058c20406a9ecdac1b4a8e38d /tensorflow/contrib/lite/kernels/internal
parentd035a83459330c87bbc527e3d480b65f32841997 (diff)
Convert more kernel signatures to use runtime shapes.
PiperOrigin-RevId: 213007905
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal')
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h193
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h197
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h4
3 files changed, 328 insertions, 66 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 659a65a8ea..464207d739 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -4431,9 +4431,9 @@ inline void LocalResponseNormalization(
}
}
-inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
- float beta, float* output_data,
- const RuntimeShape& output_shape) {
+inline void Softmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Softmax");
MatchingFlatSize(input_shape, output_shape);
@@ -4441,7 +4441,8 @@ inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
auto out_mat = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
// Compute the exponential first, removing the max coefficient for numerical
// stability.
- out_mat = (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * beta;
+ out_mat =
+ (in_mat.rowwise() - in_mat.colwise().maxCoeff()).array() * params.beta;
// We are separating out the exp function so that exp can be vectorized.
out_mat = out_mat.array().exp();
// Normalize to get the activations.
@@ -4450,10 +4451,22 @@ inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
out_mat.array().rowwise() *= scale;
}
-inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_beta_multiplier, int32 input_beta_left_shift,
- int diff_min, uint8* output_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
+ float beta, float* output_data,
const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.beta = beta;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Softmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int32 input_beta_multiplier = params.input_multiplier;
+ const int32 input_beta_left_shift = params.input_left_shift;
+ const int diff_min = params.diff_min;
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -4659,10 +4672,24 @@ inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.diff_min = diff_min;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
// TODO(myenik): This is the same as the reference implementation, not actually
// optimized yet.
-inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void LogSoftmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("LogSoftmax");
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
@@ -4695,6 +4722,15 @@ inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ // No params currently used for float LogSoftmax.
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
template <int OutputIntegerBits, int InputIntegerBits>
inline gemmlowp::FixedPoint<int32, OutputIntegerBits>
log_x_for_x_greater_than_or_equal_to_1_impl(
@@ -4809,12 +4845,15 @@ log_x_for_x_greater_than_or_equal_to_1(
}
// Currently just a copy of the reference code.
-inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_multiplier, int32 input_left_shift,
- int32 reverse_scaling_divisor,
- int32 reverse_scaling_right_shift, int diff_min,
- uint8* output_data, const RuntimeShape& output_shape) {
+inline void LogSoftmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("LogSoftmax/Uint8");
+ const int32 input_multiplier = params.input_multiplier;
+ const int32 input_left_shift = params.input_left_shift;
+ const int32 reverse_scaling_divisor = params.reverse_scaling_divisor;
+ const int32 reverse_scaling_right_shift = params.reverse_scaling_right_shift;
+ const int diff_min = params.diff_min;
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -4896,7 +4935,24 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_multiplier, int32 input_left_shift,
+ int32 reverse_scaling_divisor,
+ int32 reverse_scaling_right_shift, int diff_min,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ params.reverse_scaling_divisor = reverse_scaling_divisor;
+ params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+ params.diff_min = diff_min;
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic");
auto input_map = MapAsVector(input_data, input_shape);
@@ -4905,11 +4961,23 @@ inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
}
-inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ LogisticParams params;
+ // No params currently needed by float Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic/Uint8");
+ const int32 input_zero_point = params.input_zero_point;
+ const int32 input_range_radius = params.input_range_radius;
+ const int32 input_multiplier = params.input_multiplier;
+ const int input_left_shift = params.input_left_shift;
const int size = MatchingFlatSize(input_shape, output_shape);
int c = 0;
@@ -5042,7 +5110,22 @@ inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ LogisticParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+ const RuntimeShape& input_shape, const int16* input_data,
const RuntimeShape& output_shape, int16* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic/Int16");
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -5102,26 +5185,51 @@ inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy version.
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ LogisticParams params;
+ // No params currently needed by int16 Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
// Legacy version.
inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
int16* output_data, const RuntimeShape& output_shape) {
- Logistic(input_shape, input_data, output_shape, output_data);
+ LogisticParams params;
+ // No params currently needed by int16 Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
}
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("Tanh");
auto input_map = MapAsVector(input_data, input_shape);
auto output_map = MapAsVector(output_data, output_shape);
output_map.array() = input_map.array().tanh();
}
-inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ TanhParams params;
+ // Currently no params needed for float Tanh.
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
// Note that this is almost the exact same code as in Logistic().
gemmlowp::ScopedProfilingLabel label("Tanh");
+ const int32 input_zero_point = params.input_zero_point;
+ const int32 input_range_radius = params.input_range_radius;
+ const int32 input_multiplier = params.input_multiplier;
+ const int input_left_shift = params.input_left_shift;
const int size = MatchingFlatSize(input_shape, output_shape);
int c = 0;
@@ -5263,10 +5371,25 @@ inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
- int input_left_shift, int16* output_data,
- const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+ const int16* input_data, const RuntimeShape& output_shape,
+ int16* output_data) {
gemmlowp::ScopedProfilingLabel label("Tanh/Int16");
+ const int input_left_shift = params.input_left_shift;
// Support for shifts is limited until we have a parameterized version of
// SaturatingRoundingMultiplyByPOT().
TFLITE_DCHECK_GE(input_left_shift, 0);
@@ -5363,6 +5486,16 @@ inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
+ int input_left_shift, int16* output_data,
+ const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
template <typename SrcT, typename DstT>
inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
const RuntimeShape& output_shape, DstT* output_data) {
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 66f18ec195..111adbf5b3 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -2902,9 +2902,9 @@ inline void LocalResponseNormalization(
}
}
-inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
- float beta, float* output_data,
- const RuntimeShape& output_shape) {
+inline void Softmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
@@ -2923,21 +2923,33 @@ inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
// Compute sum.
float sum = 0.f;
for (int c = 0; c < depth; ++c) {
- sum += std::exp((input_data[i * depth + c] - max) * beta);
+ sum += std::exp((input_data[i * depth + c] - max) * params.beta);
}
// Compute result.
for (int c = 0; c < depth; ++c) {
output_data[i * depth + c] =
- std::exp((input_data[i * depth + c] - max) * beta) / sum;
+ std::exp((input_data[i * depth + c] - max) * params.beta) / sum;
}
}
}
-inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_beta_multiplier, int32 input_beta_left_shift,
- int diff_min, uint8* output_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
+ float beta, float* output_data,
const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.beta = beta;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Softmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int32 input_beta_multiplier = params.input_multiplier;
+ const int32 input_beta_left_shift = params.input_left_shift;
+ const int diff_min = params.diff_min;
// The representation chosen for the input to the exp() function is Q5.26.
// We need to leave extra space since values that we skip might be as large as
// -32 before multiplying by input_beta_multiplier, and therefore as large as
@@ -3015,8 +3027,22 @@ inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy
+inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_beta_multiplier, int32 input_beta_left_shift,
+ int diff_min, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_beta_multiplier;
+ params.input_left_shift = input_beta_left_shift;
+ params.diff_min = diff_min;
+ Softmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void LogSoftmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
@@ -3046,6 +3072,15 @@ inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy
+inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ // No params currently used for float LogSoftmax.
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
// Although currently the name of this function says that it cannot handle
// values less than 1, in practice it can handle as low as 1/x_max, where
// x_max is the largest representable input. In other words, the output range
@@ -3161,16 +3196,19 @@ log_x_for_x_greater_than_or_equal_to_1(
input_val);
}
-inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_multiplier, int32 input_left_shift,
- int32 reverse_scaling_divisor,
- int32 reverse_scaling_right_shift, int diff_min,
- uint8* output_data, const RuntimeShape& output_shape) {
+inline void LogSoftmax(const SoftmaxParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int32 input_multiplier = params.input_multiplier;
+ const int32 input_left_shift = params.input_left_shift;
+ const int32 reverse_scaling_divisor = params.reverse_scaling_divisor;
+ const int32 reverse_scaling_right_shift = params.reverse_scaling_right_shift;
+ const int diff_min = params.diff_min;
// The representation chosen for the input to the exp() function is Q5.26.
- // We need to leave extra space since values that we skip might be as large as
- // -32 before multiplying by input_beta_multiplier, and therefore as large as
- // -16 afterwards. Note that exp(-8) is definitely not insignificant to
- // accumulation, but exp(-16) definitely is.
+ // We need to leave extra space since values that we skip might be as large
+ // as -32 before multiplying by input_beta_multiplier, and therefore as
+ // large as -16 afterwards. Note that exp(-8) is definitely not
+ // insignificant to accumulation, but exp(-16) definitely is.
static constexpr int kScaledDiffIntegerBits = 5;
static constexpr int kAccumulationIntegerBits = 12;
static constexpr int kOutputIntegerBits = 4;
@@ -3247,7 +3285,24 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_multiplier, int32 input_left_shift,
+ int32 reverse_scaling_divisor,
+ int32 reverse_scaling_right_shift, int diff_min,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ SoftmaxParams params;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ params.reverse_scaling_divisor = reverse_scaling_divisor;
+ params.reverse_scaling_right_shift = reverse_scaling_right_shift;
+ params.diff_min = diff_min;
+ LogSoftmax(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+ const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3258,10 +3313,22 @@ inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
}
}
-inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ LogisticParams params;
+ // No params currently needed by float Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int32 input_zero_point = params.input_zero_point;
+ const int32 input_range_radius = params.input_range_radius;
+ const int32 input_multiplier = params.input_multiplier;
+ const int input_left_shift = params.input_left_shift;
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
@@ -3296,7 +3363,22 @@ inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ LogisticParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Logistic(const LogisticParams& params,
+ const RuntimeShape& input_shape, const int16* input_data,
const RuntimeShape& output_shape, int16* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3314,8 +3396,18 @@ inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
}
}
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ LogisticParams params;
+ // No params currently needed by int16 Logistic.
+ Logistic(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
@@ -3325,10 +3417,22 @@ inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
}
}
-inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
- int32 input_zero_point, int32 input_range_radius,
- int32 input_multiplier, int input_left_shift,
- uint8* output_data, const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ TanhParams params;
+ // Currently no params needed for float Tanh.
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+ const uint8* input_data, const RuntimeShape& output_shape,
+ uint8* output_data) {
+ const int32 input_zero_point = params.input_zero_point;
+ const int32 input_range_radius = params.input_range_radius;
+ const int32 input_multiplier = params.input_multiplier;
+ const int input_left_shift = params.input_left_shift;
const int32 output_zero_point = 128;
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3365,9 +3469,24 @@ inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
}
}
-inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
- int input_left_shift, int16* output_data,
- const RuntimeShape& output_shape) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
+ int32 input_zero_point, int32 input_range_radius,
+ int32 input_multiplier, int input_left_shift,
+ uint8* output_data, const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_zero_point = input_zero_point;
+ params.input_range_radius = input_range_radius;
+ params.input_multiplier = input_multiplier;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
+inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
+ const int16* input_data, const RuntimeShape& output_shape,
+ int16* output_data) {
+ const int input_left_shift = params.input_left_shift;
// Support for shifts is limited until we have a parameterized version of
// SaturatingRoundingMultiplyByPOT().
TFLITE_DCHECK_GE(input_left_shift, 0);
@@ -3398,6 +3517,16 @@ inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
+ int input_left_shift, int16* output_data,
+ const RuntimeShape& output_shape) {
+ TanhParams params;
+ params.input_left_shift = input_left_shift;
+ Tanh(params, input_shape, input_data, output_shape, output_data);
+}
+
inline void Dequantize(const tflite::DequantizationParams& op_params,
const RuntimeShape& input_shape, const uint8* input_data,
const RuntimeShape& output_shape, float* output_data) {
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 023707d466..87e8ff0346 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -885,8 +885,8 @@ struct SoftmaxParams {
// for LogSoftmax.
double beta;
// uint8 inference params. Used even when beta defaults to 1.0.
- int32 input_beta_multiplier;
- int32 input_beta_left_shift;
+ int32 input_multiplier;
+ int32 input_left_shift;
// Reverse scaling is only used by LogSoftmax.
int32 reverse_scaling_divisor;
int32 reverse_scaling_right_shift;