aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar andrehentz <ahentz@google.com>2018-04-12 11:05:30 -0700
committerGravatar GitHub <noreply@github.com>2018-04-12 11:05:30 -0700
commiteaaee5d7aa5df8157c4128568a8f7703458723cd (patch)
treeef7a2dbfe3e364c1f934676e71ee8165fe66f60f
parent079539b2e7acb1813cbfcdd2ab39f7bb77bc0467 (diff)
parent09ab7fc83e3b2b66a2d1ff68ac6ad1b56a61fcd6 (diff)
Merge pull request #17123 from hovhannesgithub/div-sub-broadcasting
Add broadcasting functionality for Div and Sub ops.
-rw-r--r--tensorflow/contrib/lite/kernels/div.cc5
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h39
-rw-r--r--tensorflow/contrib/lite/kernels/sub.cc3
4 files changed, 34 insertions, 15 deletions
diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc
index 6dd243ad62..ec380c8e49 100644
--- a/tensorflow/contrib/lite/kernels/div.cc
+++ b/tensorflow/contrib/lite/kernels/div.cc
@@ -106,6 +106,8 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
#undef TF_LITE_DIV
}
+
+
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteDivParams*>(node->builtin_data);
@@ -118,7 +120,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (output->type == kTfLiteFloat32) {
EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
} else {
- context->ReportError(context, "Inputs and outputs not all float types.");
+ context->ReportError(context,
+ "Div only supports FLOAT32 and quantized UINT8 now.");
return kTfLiteError;
}
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 5f60b2d6a0..fc58c192f8 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -3938,7 +3938,7 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
using FixedPointAccum = gemmlowp::FixedPoint<int32, kAccumulationIntegerBits>;
using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
- gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
+gemmlowp::ScopedProfilingLabel label("Softmax/8bit");
const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
const int height = MatchingArraySize(input_dims, 2, output_dims, 2);
const int width = MatchingArraySize(input_dims, 1, output_dims, 1);
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 0912f5928c..0fc88b2b8e 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -1255,6 +1255,33 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
output_data, output_dims);
}
+inline void Div(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ const int batches =
+ MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
+ const int height =
+ MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
+ const int width =
+ MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
+ const int depth =
+ MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
+ for (int b = 0; b < batches; ++b) {
+ for (int y = 0; y < height; ++y) {
+ for (int x = 0; x < width; ++x) {
+ for (int c = 0; c < depth; ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ ActivationFunctionWithMinMax(
+ input1_data[Offset(input1_dims, c, x, y, b)] /
+ input2_data[Offset(input2_dims, c, x, y, b)],
+ output_activation_min, output_activation_max);
+ }
+ }
+ }
+ }
+}
+
// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
@@ -1296,18 +1323,6 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
}
}
-inline void Div(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
- for (int i = 0; i < flat_size; ++i) {
- output_data[i] = ActivationFunctionWithMinMax(
- input1_data[i] / input2_data[i], output_activation_min,
- output_activation_max);
- }
-}
-
inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
const float* input2_data, const Dims<4>& input2_dims,
float output_activation_min, float output_activation_max,
diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc
index 66b06aeaec..5acb356181 100644
--- a/tensorflow/contrib/lite/kernels/sub.cc
+++ b/tensorflow/contrib/lite/kernels/sub.cc
@@ -174,7 +174,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
EvalQuantized<kernel_type>(context, node, params, data, input1, input2,
output);
} else {
- context->ReportError(context, "Inputs and outputs not all float types.");
+ context->ReportError(context,
+ "Inputs and outputs not all float|unit8 types.");
return kTfLiteError;
}