aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tflite
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-28 01:52:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 01:57:31 -0700
commit3e13ae966115b1aaf793601b0647b40efb25a2da (patch)
tree50aa649f843a698d23c99a4d6ef9c5adc6752895 /tensorflow/contrib/lite/toco/tflite
parentf255b51c6e637ac7701996b4457157d3c313dca4 (diff)
Implementation of reduce_any.
PiperOrigin-RevId: 210507220
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite')
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc27
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc27
2 files changed, 44 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index e9383098cc..f687e9689e 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -769,7 +769,7 @@ class Sum
};
class ReduceMax
- : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
+ : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions,
::tflite::BuiltinOptions_ReducerOptions> {
public:
using BuiltinOperator::BuiltinOperator;
@@ -788,7 +788,7 @@ class ReduceMax
};
class ReduceMin
- : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
+ : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions,
::tflite::BuiltinOptions_ReducerOptions> {
public:
using BuiltinOperator::BuiltinOperator;
@@ -807,7 +807,26 @@ class ReduceMin
};
class ReduceProd
- : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
+ : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions,
+ ::tflite::BuiltinOptions_ReducerOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->keep_dims = options.keep_dims();
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
+class ReduceAny
+ : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions,
::tflite::BuiltinOptions_ReducerOptions> {
public:
using BuiltinOperator::BuiltinOperator;
@@ -1336,6 +1355,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
OperatorType::kReduceMax));
ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN,
OperatorType::kReduceMin));
+ ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY,
+ OperatorType::kAny));
ops.push_back(
MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR,
OperatorType::kResizeBilinear));
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index bb0b457483..6da9317e4f 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -97,6 +97,16 @@ class OperatorTest : public ::testing::Test {
ASSERT_NE(nullptr, output_toco_op.get());
}
+
+ template <typename T>
+ void CheckReducerOperator(const string& name, OperatorType type) {
+ T op;
+
+ op.keep_dims = false;
+
+ auto output_toco_op = SerializeAndDeserialize(GetOperator(name, type), op);
+ EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims);
+ }
};
TEST_F(OperatorTest, SimpleOperators) {
@@ -144,13 +154,16 @@ TEST_F(OperatorTest, BuiltinAdd) {
output_toco_op->fused_activation_function);
}
-TEST_F(OperatorTest, BuiltinMean) {
- MeanOperator op;
- op.keep_dims = false;
-
- auto output_toco_op =
- SerializeAndDeserialize(GetOperator("MEAN", OperatorType::kMean), op);
- EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims);
+TEST_F(OperatorTest, BuiltinReducerOps) {
+ CheckReducerOperator<MeanOperator>("MEAN", OperatorType::kMean);
+ CheckReducerOperator<TensorFlowSumOperator>("SUM", OperatorType::kSum);
+ CheckReducerOperator<TensorFlowProdOperator>("REDUCE_PROD",
+ OperatorType::kReduceProd);
+ CheckReducerOperator<TensorFlowMaxOperator>("REDUCE_MAX",
+ OperatorType::kReduceMax);
+ CheckReducerOperator<TensorFlowMinOperator>("REDUCE_MIN",
+ OperatorType::kReduceMin);
+ CheckReducerOperator<TensorFlowAnyOperator>("REDUCE_ANY", OperatorType::kAny);
}
TEST_F(OperatorTest, BuiltinCast) {