diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-27 00:20:43 -0700 |
---|---|---|
committer | Gunhan Gulsoy <gunan@google.com> | 2018-06-28 21:37:43 -0700 |
commit | 51c80b60492e8095999d1f1194c8d56e6d222719 (patch) | |
tree | 5bd267278b7d71ca816a285a633c2cf140b83ae1 /tensorflow/contrib/lite | |
parent | 11157efc4e94a7c70ff7532d7bb835fb5d9d19da (diff) |
Implementation of pow.
PiperOrigin-RevId: 202262513
Diffstat (limited to 'tensorflow/contrib/lite')
21 files changed, 492 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 81883ba1fd..5543acc1f5 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -233,6 +233,7 @@ def generated_test_models(): "pad", "padv2", # "prelu", + "pow", "relu", "relu1", "relu6", diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 7a78206ebf..a44e918230 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -103,6 +103,7 @@ typedef enum { kTfLiteBuiltinSqrt = 75, kTfLiteBuiltinRsqrt = 76, kTfLiteBuiltinShape = 77, + kTfLiteBuiltinPow = 78, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 45104c1419..dcd17bbeab 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -778,6 +778,18 @@ Outputs { } ``` +**POW** + +``` +Inputs { + 0: a tensor + 1: a tensor +} +Outputs { + 0: elementwise pow of the input tensors +} +``` + And these are TensorFlow Lite operations that are present but not ready for custom models yet: diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index a77897a173..61d5af3478 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -163,6 +163,7 @@ cc_library( "neg.cc", "pad.cc", "pooling.cc", + "pow.cc", "reduce.cc", "register.cc", "reshape.cc", @@ -1009,6 +1010,20 @@ tf_cc_test( ], ) +tf_cc_test( + name = "pow_test", + size = "small", + srcs = ["pow_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index e4653123f6..089e743b1f 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -4070,6 +4070,36 @@ inline void SparseToDense(const std::vector<std::vector<I>>& indices, } } +template <typename T> +inline void Pow(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims); + for (int i = 0; i < flat_size; ++i) { + output_data[i] = std::pow(input1_data[i], input2_data[i]); + } +} + +template <typename T> +inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + NdArrayDesc<4> desc1; + NdArrayDesc<4> desc2; + NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2); + for (int b = 0; b < ArraySize(output_dims, 3); ++b) { + for (int y = 0; y < ArraySize(output_dims, 2); ++y) { + for (int x = 0; x < ArraySize(output_dims, 1); ++x) { + for (int c = 0; c < ArraySize(output_dims, 0); ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + std::pow(input1_data[SubscriptToIndex(desc1, c, x, y, b)], + input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + } + } + } + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/contrib/lite/kernels/pow.cc new file mode 100644 index 0000000000..4a539c47a8 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pow.cc @@ -0,0 +1,143 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace pow { +namespace { + +// Input/output tensor index. +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +// Op data for pow op. +struct OpData { + bool requires_broadcast; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->requires_broadcast = false; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast<OpData*>(buffer); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + OpData* data = reinterpret_cast<OpData*>(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + + const TfLiteType type = input1->type; + if (type != kTfLiteInt32 && type != kTfLiteFloat32) { + context->ReportError(context, "Unsupported data type %d.", type); + return kTfLiteError; + } + output->type = type; + + data->requires_broadcast = !HaveSameShapes(input1, input2); + + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } + + return context->ResizeTensor(context, output, output_size); +} + +template <typename T> +void PowImpl(const TfLiteTensor* input1, const TfLiteTensor* input2, + TfLiteTensor* output, bool requires_broadcast) { + if (requires_broadcast) { + reference_ops::BroadcastPow(GetTensorData<T>(input1), GetTensorDims(input1), + GetTensorData<T>(input2), GetTensorDims(input2), + GetTensorData<T>(output), + GetTensorDims(output)); + } else { + reference_ops::Pow(GetTensorData<T>(input1), GetTensorDims(input1), + GetTensorData<T>(input2), GetTensorDims(input2), + GetTensorData<T>(output), GetTensorDims(output)); + } +} + +TfLiteStatus CheckValue(TfLiteContext* context, const TfLiteTensor* input) { + const int64_t num_elements = NumElements(input); + const int32_t* data = GetTensorData<int32_t>(input); + for (int i = 0; i < num_elements; ++i) { + if (data[i] < 0) { + context->ReportError(context, + "POW does not support negative value for int32."); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast<OpData*>(node->user_data); + + const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + switch (output->type) { + case kTfLiteInt32: { + // TensorFlow does not support negative for int32. + TF_LITE_ENSURE_OK(context, CheckValue(context, input2)); + PowImpl<int32_t>(input1, input2, output, data->requires_broadcast); + break; + } + case kTfLiteFloat32: { + PowImpl<float>(input1, input2, output, data->requires_broadcast); + break; + } + default: { + context->ReportError(context, "Unsupported data type: %d", output->type); + return kTfLiteError; + } + } + return kTfLiteOk; +} + +} // namespace +} // namespace pow + +TfLiteRegistration* Register_POW() { + static TfLiteRegistration r = {pow::Init, pow::Free, pow::Prepare, pow::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/pow_test.cc b/tensorflow/contrib/lite/kernels/pow_test.cc new file mode 100644 index 0000000000..474d323bc3 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pow_test.cc @@ -0,0 +1,117 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <gtest/gtest.h> +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAre; +using ::testing::ElementsAreArray; + +template <typename T> +class PowOpModel : public SingleOpModel { + public: + PowOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output) { + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_POW, BuiltinOptions_PowOptions, + CreatePowOptions(builder_).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector<T> GetOutput() { return ExtractVector<T>(output_); } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + private: + int input1_; + int input2_; + int output_; +}; + +TEST(PowOpModel, Simple) { + PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {}}); + model.PopulateTensor<int32>(model.input1(), {12, 2, 7, 8}); + model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 1}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), ElementsAre(12, 4, 343, 8)); +} + +TEST(PowOpModel, NegativeAndZeroValue) { + PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {}}); + model.PopulateTensor<int32>(model.input1(), {0, 2, -7, 8}); + model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 0}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), ElementsAre(0, 4, -343, 1)); +} + +TEST(PowOpModel, Float) { + PowOpModel<float> model({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + model.PopulateTensor<float>(model.input1(), {0.3, 0.4, 0.7, 5.8}); + model.PopulateTensor<float>(model.input2(), {0.5, 2.7, 3.1, 3.2}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.5477226, 0.08424846, 0.33098164, 277.313}, 1e-3))); +} + +TEST(PowOpModel, NegativeFloatTest) { + PowOpModel<float> model({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + model.PopulateTensor<float>(model.input1(), {0.3, 0.4, 0.7, 5.8}); + model.PopulateTensor<float>(model.input2(), {0.5, -2.7, 3.1, -3.2}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.5477226, 11.869653, 0.33098164, 0.003606}, 1e-3))); +} + +TEST(PowOpModel, BroadcastTest) { + PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}}, + {TensorType_INT32, {1}}, {TensorType_INT32, {}}); + model.PopulateTensor<int32>(model.input1(), {12, 2, 7, 8}); + model.PopulateTensor<int32>(model.input2(), {4}); + model.Invoke(); + EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1)); + EXPECT_THAT(model.GetOutput(), ElementsAre(20736, 16, 2401, 4096)); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index f04fdc67c0..0ca08cd8f3 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -101,6 +101,7 @@ TfLiteRegistration* Register_NOT_EQUAL(); TfLiteRegistration* Register_SQRT(); TfLiteRegistration* Register_RSQRT(); TfLiteRegistration* Register_SHAPE(); +TfLiteRegistration* Register_POW(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -185,6 +186,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SQRT, Register_SQRT()); AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT()); AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE()); + AddBuiltin(BuiltinOperator_POW, Register_POW()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 04327a44db..793a72272d 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -735,6 +735,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_TILE: case BuiltinOperator_TOPK_V2: case BuiltinOperator_TRANSPOSE: + case BuiltinOperator_POW: break; } return kTfLiteOk; diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index ab007993af..748c2f1a04 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -504,6 +504,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_SQRT: case tflite::BuiltinOperator_RSQRT: case tflite::BuiltinOperator_SHAPE: + case tflite::BuiltinOperator_POW: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 5a53ef124d..76ad3ef893 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -158,6 +158,7 @@ enum BuiltinOperator : byte { SQRT = 75, RSQRT = 76, SHAPE = 77, + POW = 78, } // Options for the builtin operators. @@ -217,6 +218,7 @@ union BuiltinOptions { EqualOptions, NotEqualOptions, ShapeOptions, + PowOptions, } enum Padding : byte { SAME, VALID } @@ -511,6 +513,9 @@ table ShapeOptions { out_type : TensorType; } +table PowOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 0b8c6387c2..e3ce90aa55 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -196,6 +196,9 @@ struct NotEqualOptionsT; struct ShapeOptions; struct ShapeOptionsT; +struct PowOptions; +struct PowOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -336,11 +339,12 @@ enum BuiltinOperator { BuiltinOperator_SQRT = 75, BuiltinOperator_RSQRT = 76, BuiltinOperator_SHAPE = 77, + BuiltinOperator_POW = 78, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_SHAPE + BuiltinOperator_MAX = BuiltinOperator_POW }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[77] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[78] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -418,7 +422,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[77] { BuiltinOperator_SUM, BuiltinOperator_SQRT, BuiltinOperator_RSQRT, - BuiltinOperator_SHAPE + BuiltinOperator_SHAPE, + BuiltinOperator_POW }; return values; } @@ -503,6 +508,7 @@ inline const char **EnumNamesBuiltinOperator() { "SQRT", "RSQRT", "SHAPE", + "POW", nullptr }; return names; @@ -570,11 +576,12 @@ enum BuiltinOptions { BuiltinOptions_EqualOptions = 53, BuiltinOptions_NotEqualOptions = 54, BuiltinOptions_ShapeOptions = 55, + BuiltinOptions_PowOptions = 56, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_ShapeOptions + BuiltinOptions_MAX = BuiltinOptions_PowOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[56] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[57] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -631,7 +638,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[56] { BuiltinOptions_ExpandDimsOptions, BuiltinOptions_EqualOptions, BuiltinOptions_NotEqualOptions, - BuiltinOptions_ShapeOptions + BuiltinOptions_ShapeOptions, + BuiltinOptions_PowOptions }; return values; } @@ -694,6 +702,7 @@ inline const char **EnumNamesBuiltinOptions() { "EqualOptions", "NotEqualOptions", "ShapeOptions", + "PowOptions", nullptr }; return names; @@ -928,6 +937,10 @@ template<> struct BuiltinOptionsTraits<ShapeOptions> { static const BuiltinOptions enum_value = BuiltinOptions_ShapeOptions; }; +template<> struct BuiltinOptionsTraits<PowOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_PowOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1399,6 +1412,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_ShapeOptions ? reinterpret_cast<const ShapeOptionsT *>(value) : nullptr; } + PowOptionsT *AsPowOptions() { + return type == BuiltinOptions_PowOptions ? + reinterpret_cast<PowOptionsT *>(value) : nullptr; + } + const PowOptionsT *AsPowOptions() const { + return type == BuiltinOptions_PowOptions ? + reinterpret_cast<const PowOptionsT *>(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -5048,6 +5069,46 @@ inline flatbuffers::Offset<ShapeOptions> CreateShapeOptions( flatbuffers::Offset<ShapeOptions> CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct PowOptionsT : public flatbuffers::NativeTable { + typedef PowOptions TableType; + PowOptionsT() { + } +}; + +struct PowOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef PowOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + PowOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(PowOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<PowOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct PowOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit PowOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + PowOptionsBuilder &operator=(const PowOptionsBuilder &); + flatbuffers::Offset<PowOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<PowOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<PowOptions> CreatePowOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + PowOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset<PowOptions> CreatePowOptions(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -5346,6 +5407,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const ShapeOptions *builtin_options_as_ShapeOptions() const { return builtin_options_type() == BuiltinOptions_ShapeOptions ? static_cast<const ShapeOptions *>(builtin_options()) : nullptr; } + const PowOptions *builtin_options_as_PowOptions() const { + return builtin_options_type() == BuiltinOptions_PowOptions ? static_cast<const PowOptions *>(builtin_options()) : nullptr; + } const flatbuffers::Vector<uint8_t> *custom_options() const { return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); } @@ -5597,6 +5661,10 @@ template<> inline const ShapeOptions *Operator::builtin_options_as<ShapeOptions> return builtin_options_as_ShapeOptions(); } +template<> inline const PowOptions *Operator::builtin_options_as<PowOptions>() const { + return builtin_options_as_PowOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -7576,6 +7644,29 @@ inline flatbuffers::Offset<ShapeOptions> CreateShapeOptions(flatbuffers::FlatBuf _out_type); } +inline PowOptionsT *PowOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new PowOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void PowOptions::UnPackTo(PowOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset<PowOptions> PowOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreatePowOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<PowOptions> CreatePowOptions(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PowOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreatePowOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -7985,6 +8076,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast<const ShapeOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast<const PowOptions *>(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -8223,6 +8318,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast<const ShapeOptions *>(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast<const PowOptions *>(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -8449,6 +8548,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast<const ShapeOptionsT *>(value); return CreateShapeOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast<const PowOptionsT *>(value); + return CreatePowOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -8675,6 +8778,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new ShapeOptionsT(*reinterpret_cast<ShapeOptionsT *>(u.value)); break; } + case BuiltinOptions_PowOptions: { + value = new PowOptionsT(*reinterpret_cast<PowOptionsT *>(u.value)); + break; + } default: break; } @@ -8957,6 +9064,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_PowOptions: { + auto ptr = reinterpret_cast<PowOptionsT *>(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index d62a311c59..1360f1a273 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -990,6 +990,10 @@ def make_mul_tests(zip_path): make_binary_op_tests(zip_path, tf.multiply) +def make_pow_tests(zip_path): + make_binary_op_tests(zip_path, tf.pow) + + def make_gather_tests(zip_path): """Make a set of tests to do gather.""" diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 6b78f1c05e..48febfb301 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -1773,6 +1773,20 @@ void ConvertSparseToDenseOperator(const Model& model, src_op.validate_indices); } +void ConvertPowOperator(const Model& model, const PowOperator& src_op, + const char* op_name, GraphDef* tensorflow_graph) { + tensorflow::NodeDef* pow_op = tensorflow_graph->add_node(); + pow_op->set_op(op_name); + pow_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + for (int i = 0; i < 2; ++i) { + *pow_op->add_input() = src_op.inputs[i]; + } + const tensorflow::DataType data_type = + GetTensorFlowDataType(model, src_op.inputs[0]); + (*pow_op->mutable_attr())["T"].set_type(data_type); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -1987,6 +2001,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, ConvertTileOperator(model, static_cast<const TensorFlowTileOperator&>(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kPow) { + ConvertPowOperator(model, static_cast<const PowOperator&>(src_op), "Pow", + tensorflow_graph); } else { LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type); } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 27a1049eaf..00ab7cbaa9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -175,6 +175,14 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { SetDataTypeForAllOutputs(model, op, data_type); break; } + case OperatorType::kPow: { + CHECK_EQ(op->inputs.size(), 2); + CHECK(model->GetArray(op->inputs[0]).data_type == + model->GetArray(op->inputs[1]).data_type); + const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type; + SetDataTypeForAllOutputs(model, op, data_type); + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index c61da203c6..c9c9f13d2e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1611,6 +1611,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kGreaterEqual: case OperatorType::kEqual: case OperatorType::kNotEqual: + case OperatorType::kPow: ProcessSimpleBinaryOperator(model, op); break; case OperatorType::kAddN: diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 485e853e25..55e39d963f 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1899,6 +1899,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"ParallelDynamicStitch", ConvertDynamicStitchOperator}, {"Placeholder", ConvertPlaceholderOperator}, {"PlaceholderWithDefault", ConvertIdentityOperator}, + {"Pow", ConvertSimpleOperator<PowOperator, 2>}, {"RandomUniform", ConvertRandomUniform}, {"Range", ConvertRangeOperator}, {"Rank", ConvertSimpleOperator<RankOperator, 1>}, diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 89cb061499..aa05a9bd0e 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -138,6 +138,7 @@ enum class OperatorType : uint8 { kSparseToDense, kEqual, kNotEqual, + kPow, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -1637,6 +1638,17 @@ struct SparseToDenseOperator : Operator { bool validate_indices; }; +// Pow operator: +// +// Inputs: +// Inputs[0]: required: A tensor. +// Inputs[1]: required: A tensor. +// +// TensorFlow equivalent: Pow. +struct PowOperator : Operator { + PowOperator() : Operator(OperatorType::kPow) {} +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 2d7a4a7a4c..7e55ae92bd 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -1237,6 +1237,7 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() { new SimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect)); ops.emplace_back( new SimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice)); + ops.emplace_back(new SimpleOperator<PowOperator>("POW", OperatorType::kPow)); // Element-wise operator ops.emplace_back(new SimpleOperator<SinOperator>("SIN", OperatorType::kSin)); ops.emplace_back(new SimpleOperator<LogOperator>("LOG", OperatorType::kLog)); diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 79c8e5d738..8b6808d3c7 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -126,6 +126,7 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator<LogOperator>("LOG", OperatorType::kLog); CheckSimpleOperator<TensorFlowSqrtOperator>("SQRT", OperatorType::kSqrt); CheckSimpleOperator<TensorFlowRsqrtOperator>("RSQRT", OperatorType::kRsqrt); + CheckSimpleOperator<PowOperator>("POW", OperatorType::kPow); } TEST_F(OperatorTest, BuiltinAdd) { diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 3d9fa732bd..7dc1af9f1d 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -396,6 +396,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(SparseToDense) HANDLE_OPERATORTYPENAME_CASE(Equal) HANDLE_OPERATORTYPENAME_CASE(NotEqual) + HANDLE_OPERATORTYPENAME_CASE(Pow) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE |