diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-28 01:52:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 01:57:31 -0700 |
commit | 3e13ae966115b1aaf793601b0647b40efb25a2da (patch) | |
tree | 50aa649f843a698d23c99a4d6ef9c5adc6752895 /tensorflow/contrib/lite/toco/tflite | |
parent | f255b51c6e637ac7701996b4457157d3c313dca4 (diff) |
Implementation of reduce_any.
PiperOrigin-RevId: 210507220
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite')
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.cc | 27 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator_test.cc | 27 |
2 files changed, 44 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index e9383098cc..f687e9689e 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -769,7 +769,7 @@ class Sum }; class ReduceMax - : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions, + : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions, ::tflite::BuiltinOptions_ReducerOptions> { public: using BuiltinOperator::BuiltinOperator; @@ -788,7 +788,7 @@ class ReduceMax }; class ReduceMin - : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions, + : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions, ::tflite::BuiltinOptions_ReducerOptions> { public: using BuiltinOperator::BuiltinOperator; @@ -807,7 +807,26 @@ class ReduceMin }; class ReduceProd - : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions, + : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions, + ::tflite::BuiltinOptions_ReducerOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateReducerOptions(*builder, op.keep_dims); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->keep_dims = options.keep_dims(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + +class ReduceAny + : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions, ::tflite::BuiltinOptions_ReducerOptions> { public: using BuiltinOperator::BuiltinOperator; @@ -1336,6 +1355,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { OperatorType::kReduceMax)); ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN, OperatorType::kReduceMin)); + ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY, + OperatorType::kAny)); ops.push_back( MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR, OperatorType::kResizeBilinear)); diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index bb0b457483..6da9317e4f 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -97,6 +97,16 @@ class OperatorTest : public ::testing::Test { ASSERT_NE(nullptr, output_toco_op.get()); } + + template <typename T> + void CheckReducerOperator(const string& name, OperatorType type) { + T op; + + op.keep_dims = false; + + auto output_toco_op = SerializeAndDeserialize(GetOperator(name, type), op); + EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims); + } }; TEST_F(OperatorTest, SimpleOperators) { @@ -144,13 +154,16 @@ TEST_F(OperatorTest, BuiltinAdd) { output_toco_op->fused_activation_function); } -TEST_F(OperatorTest, BuiltinMean) { - MeanOperator op; - op.keep_dims = false; - - auto output_toco_op = - SerializeAndDeserialize(GetOperator("MEAN", OperatorType::kMean), op); - EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims); +TEST_F(OperatorTest, BuiltinReducerOps) { + CheckReducerOperator<MeanOperator>("MEAN", OperatorType::kMean); + CheckReducerOperator<TensorFlowSumOperator>("SUM", OperatorType::kSum); + CheckReducerOperator<TensorFlowProdOperator>("REDUCE_PROD", + OperatorType::kReduceProd); + CheckReducerOperator<TensorFlowMaxOperator>("REDUCE_MAX", + OperatorType::kReduceMax); + CheckReducerOperator<TensorFlowMinOperator>("REDUCE_MIN", + OperatorType::kReduceMin); + CheckReducerOperator<TensorFlowAnyOperator>("REDUCE_ANY", OperatorType::kAny); } TEST_F(OperatorTest, BuiltinCast) { |