aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-19 12:26:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-19 12:29:52 -0700
commitc8e17b08263f3ba61a5fdf785e231e1c9f4029ca (patch)
treee20f4fd1704ae3f9a8cf94c6040274b8e301d108 /tensorflow/contrib/lite/kernels
parent238424ffcf04c38561ed48ebadb16b3b3a770e2e (diff)
Convert more kernel signatures to use runtime shapes.
PiperOrigin-RevId: 213673402
Diffstat (limited to 'tensorflow/contrib/lite/kernels')
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h60
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h74
2 files changed, 90 insertions, 44 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 6a7e664e85..1a2d45166a 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -3804,11 +3804,11 @@ inline void LstmCell(
uint8* concat_temp_data_uint8,
const RuntimeShape& unextended_activ_temp_shape,
int16* activ_temp_data_int16, gemmlowp::GemmContext* gemm_context) {
+ gemmlowp::ScopedProfilingLabel label(
+ "LstmCell/quantized (8bit external, 16bit internal)");
int32 weights_zero_point = params.weights_zero_point;
int32 accum_multiplier = params.accum_multiplier;
int accum_shift = params.accum_shift;
- gemmlowp::ScopedProfilingLabel label(
- "LstmCell/quantized (8bit external, 16bit internal)");
TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_prev_activ_shape.DimensionsCount(), 4);
TFLITE_DCHECK_LE(unextended_bias_shape.DimensionsCount(), 4);
@@ -5063,8 +5063,7 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
LogSoftmax(params, input_shape, input_data, output_shape, output_data);
}
-inline void Logistic(const LogisticParams& params,
- const RuntimeShape& input_shape, const float* input_data,
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Logistic");
auto input_map = MapAsVector(input_data, input_shape);
@@ -5073,13 +5072,13 @@ inline void Logistic(const LogisticParams& params,
input_map.array().unaryExpr(Eigen::internal::scalar_sigmoid_op<float>());
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- LogisticParams params;
- // No params currently needed by float Logistic.
- Logistic(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Logistic(input_shape, input_data, output_shape, output_data);
}
inline void Logistic(const LogisticParams& params,
@@ -5315,22 +5314,21 @@ inline void Logistic(const int16* input_data, const RuntimeShape& input_shape,
Logistic(params, input_shape, input_data, output_shape, output_data);
}
-inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
- const float* input_data, const RuntimeShape& output_shape,
- float* output_data) {
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Tanh");
auto input_map = MapAsVector(input_data, input_shape);
auto output_map = MapAsVector(output_data, output_shape);
output_map.array() = input_map.array().tanh();
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- TanhParams params;
- // Currently no params needed for float Tanh.
- Tanh(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Tanh(input_shape, input_data, output_shape, output_data);
}
inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
@@ -6385,6 +6383,16 @@ void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
output_map.array() = input1_map.array().min(min_value);
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,
@@ -6396,6 +6404,16 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
output_map.array() = input1_map.array().max(max_value);
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void TransposeIm2col(const ConvParams& params, uint8 zero_byte,
const RuntimeShape& input_shape, const T* input_data,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 76fa1944bc..bb1d30b216 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -1916,7 +1916,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
const float* input2_data,
const RuntimeShape& output_shape,
float* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/float");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/float");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -1957,7 +1957,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
const uint8* input2_data,
const RuntimeShape& output_shape,
uint8* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/uint8");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/uint8");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -2021,7 +2021,7 @@ inline void BroadcastSub4DSlow(const ArithmeticParams& params,
const int32* input2_data,
const RuntimeShape& output_shape,
int32* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/int32");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/int32");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -2061,7 +2061,7 @@ void BroadcastSub4DSlow(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const T* input1_data,
const RuntimeShape& input2_shape, const T* input2_data,
const RuntimeShape& output_shape, T* output_data) {
- gemmlowp::ScopedProfilingLabel label("BroadcastAdd4DSlow/templated");
+ gemmlowp::ScopedProfilingLabel label("BroadcastSub4DSlow/templated");
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
@@ -3637,8 +3637,7 @@ inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
LogSoftmax(params, input_shape, input_data, output_shape, output_data);
}
-inline void Logistic(const LogisticParams& params,
- const RuntimeShape& input_shape, const float* input_data,
+inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -3649,13 +3648,13 @@ inline void Logistic(const LogisticParams& params,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Logistic(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- LogisticParams params;
- // No params currently needed by float Logistic.
- Logistic(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Logistic(const LogisticParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Logistic(input_shape, input_data, output_shape, output_data);
}
inline void Logistic(const LogisticParams& params,
@@ -3741,9 +3740,8 @@ inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
Logistic(params, input_shape, input_data, output_shape, output_data);
}
-inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
- const float* input_data, const RuntimeShape& output_shape,
- float* output_data) {
+inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
@@ -3753,13 +3751,13 @@ inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
}
}
-// TODO(b/80418076): Move to legacy ops file, update invocations.
-// Legacy.
-inline void Tanh(const RuntimeShape& input_shape, const float* input_data,
- const RuntimeShape& output_shape, float* output_data) {
- TanhParams params;
- // Currently no params needed for float Tanh.
- Tanh(params, input_shape, input_data, output_shape, output_data);
+// Convenience version that allows, for example, generated-code calls to be
+// uniform between data types.
+inline void Tanh(const TanhParams&, const RuntimeShape& input_shape,
+ const float* input_data, const RuntimeShape& output_shape,
+ float* output_data) {
+ // Drop params: not needed.
+ Tanh(input_shape, input_data, output_shape, output_data);
}
inline void Tanh(const TanhParams& params, const RuntimeShape& input_shape,
@@ -4735,6 +4733,16 @@ void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
}
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Minimum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,
@@ -4747,6 +4755,16 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
}
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T>
+inline void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape&, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Drop shape of second input: not needed.
+ Maximum(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T, typename Op>
void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
const T* input1_data,
@@ -4822,6 +4840,16 @@ void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
std::greater<T1>());
}
+// Convenience version that allows, for example, generated-code calls to be
+// the same as other binary ops.
+template <typename T1, typename T2, typename T3>
+inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
+ const RuntimeShape& input2_shape, const T3* input2_data,
+ const RuntimeShape& output_shape, T2* output_data) {
+ // Drop shape of second input: not needed.
+ ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
+}
+
template <typename T>
void Transpose(const TransposeParams& params,
const RuntimeShape& unextended_input_shape, const T* input_data,