author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-27 00:20:43 -0700
committer Gunhan Gulsoy <gunan@google.com>  2018-06-28 21:37:43 -0700
commit    51c80b60492e8095999d1f1194c8d56e6d222719 (patch)
tree      5bd267278b7d71ca816a285a633c2cf140b83ae1 /tensorflow/contrib/lite
parent    11157efc4e94a7c70ff7532d7bb835fb5d9d19da (diff)
Implementation of pow.
PiperOrigin-RevId: 202262513
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r--  tensorflow/contrib/lite/build_def.bzl | 1
-rw-r--r--  tensorflow/contrib/lite/builtin_ops.h | 1
-rw-r--r--  tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md | 12
-rw-r--r--  tensorflow/contrib/lite/kernels/BUILD | 15
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h | 30
-rw-r--r--  tensorflow/contrib/lite/kernels/pow.cc | 143
-rw-r--r--  tensorflow/contrib/lite/kernels/pow_test.cc | 117
-rw-r--r--  tensorflow/contrib/lite/kernels/register.cc | 2
-rw-r--r--  tensorflow/contrib/lite/model.cc | 1
-rw-r--r--  tensorflow/contrib/lite/nnapi_delegate.cc | 1
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs | 5
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h | 124
-rw-r--r--  tensorflow/contrib/lite/testing/generate_examples.py | 4
-rw-r--r--  tensorflow/contrib/lite/toco/export_tensorflow.cc | 17
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | 8
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 1
-rw-r--r--  tensorflow/contrib/lite/toco/import_tensorflow.cc | 1
-rw-r--r--  tensorflow/contrib/lite/toco/model.h | 12
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator.cc | 1
-rw-r--r--  tensorflow/contrib/lite/toco/tflite/operator_test.cc | 1
-rw-r--r--  tensorflow/contrib/lite/toco/tooling_util.cc | 1
21 files changed, 492 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 81883ba1fd..5543acc1f5 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -233,6 +233,7 @@ def generated_test_models():
"pad",
"padv2",
# "prelu",
+ "pow",
"relu",
"relu1",
"relu6",
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 7a78206ebf..a44e918230 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -103,6 +103,7 @@ typedef enum {
kTfLiteBuiltinSqrt = 75,
kTfLiteBuiltinRsqrt = 76,
kTfLiteBuiltinShape = 77,
+ kTfLiteBuiltinPow = 78,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index 45104c1419..dcd17bbeab 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -778,6 +778,18 @@ Outputs {
}
```
+**POW**
+
+```
+Inputs {
+ 0: a tensor
+ 1: a tensor
+}
+Outputs {
+ 0: elementwise pow of the input tensors
+}
+```
+
And these are TensorFlow Lite operations that are present but not ready for
custom models yet:
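
[Editor's note on the POW doc entry added above: "elementwise pow" means input 0 raised to input 1, position by position. A minimal plain-C++ sketch of the semantics, outside of TFLite; the input values here are hypothetical.]

```c++
#include <cmath>
#include <cstdio>

int main() {
  // Hypothetical inputs; the two tensors must have the same shape
  // (or shapes that broadcast, see BroadcastPow below).
  const float x[4] = {2.0f, 3.0f, 4.0f, 5.0f};
  const float y[4] = {4.0f, 2.0f, 0.5f, 1.0f};
  float out[4];
  for (int i = 0; i < 4; ++i) {
    out[i] = std::pow(x[i], y[i]);  // elementwise: 16, 9, 2, 5
  }
  for (int i = 0; i < 4; ++i) std::printf("%g\n", out[i]);
  return 0;
}
```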
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index a77897a173..61d5af3478 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -163,6 +163,7 @@ cc_library(
"neg.cc",
"pad.cc",
"pooling.cc",
+ "pow.cc",
"reduce.cc",
"register.cc",
"reshape.cc",
@@ -1009,6 +1010,20 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "pow_test",
+ size = "small",
+ srcs = ["pow_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index e4653123f6..089e743b1f 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -4070,6 +4070,36 @@ inline void SparseToDense(const std::vector<std::vector<I>>& indices,
}
}
+template <typename T>
+inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = std::pow(input1_data[i], input2_data[i]);
+ }
+}
+
+template <typename T>
+inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ NdArrayDesc<4> desc1;
+ NdArrayDesc<4> desc2;
+ NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+ for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+ for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+ for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ output_data[Offset(output_dims, c, x, y, b)] =
+ std::pow(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
+ input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ }
+ }
+ }
+ }
+}
+
} // namespace reference_ops
} // namespace tflite
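
[Editor's note: `Pow` above runs over the flattened buffer when the shapes match; `BroadcastPow` walks the output in NHWC order and maps each output coordinate back into the (possibly smaller) inputs via `NdArrayDesc`. A self-contained sketch of the same idea for the simplest broadcast case, one scalar exponent, without the `MatchingFlatSize`/`Offset` helpers from the real headers.]

```c++
#include <cmath>
#include <cstdio>

// Minimal stand-in for BroadcastPow with input2 broadcast from a single
// element, i.e. shapes {1, 2, 2, 1} vs {1}. The real kernel generalizes
// this per dimension via NdArrayDescsForElementwiseBroadcast.
void BroadcastPowScalar(const int* input1, int size, int exponent, int* out) {
  for (int i = 0; i < size; ++i) {
    out[i] = static_cast<int>(std::pow(input1[i], exponent));
  }
}

int main() {
  const int input1[4] = {12, 2, 7, 8};  // values from pow_test.cc BroadcastTest
  int out[4];
  BroadcastPowScalar(input1, 4, /*exponent=*/4, out);
  for (int i = 0; i < 4; ++i) std::printf("%d\n", out[i]);  // 20736 16 2401 4096
  return 0;
}
```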
diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/contrib/lite/kernels/pow.cc
new file mode 100644
index 0000000000..4a539c47a8
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pow.cc
@@ -0,0 +1,143 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace pow {
+namespace {
+
+// Input/output tensor index.
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+// Op data for pow op.
+struct OpData {
+ bool requires_broadcast;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->requires_broadcast = false;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
+
+ const TfLiteType type = input1->type;
+ if (type != kTfLiteInt32 && type != kTfLiteFloat32) {
+ context->ReportError(context, "Unsupported data type %d.", type);
+ return kTfLiteError;
+ }
+ output->type = type;
+
+ data->requires_broadcast = !HaveSameShapes(input1, input2);
+
+ TfLiteIntArray* output_size = nullptr;
+ if (data->requires_broadcast) {
+ TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
+ context, input1, input2, &output_size));
+ } else {
+ output_size = TfLiteIntArrayCopy(input1->dims);
+ }
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <typename T>
+void PowImpl(const TfLiteTensor* input1, const TfLiteTensor* input2,
+ TfLiteTensor* output, bool requires_broadcast) {
+ if (requires_broadcast) {
+ reference_ops::BroadcastPow(GetTensorData<T>(input1), GetTensorDims(input1),
+ GetTensorData<T>(input2), GetTensorDims(input2),
+ GetTensorData<T>(output),
+ GetTensorDims(output));
+ } else {
+ reference_ops::Pow(GetTensorData<T>(input1), GetTensorDims(input1),
+ GetTensorData<T>(input2), GetTensorDims(input2),
+ GetTensorData<T>(output), GetTensorDims(output));
+ }
+}
+
+TfLiteStatus CheckValue(TfLiteContext* context, const TfLiteTensor* input) {
+ const int64_t num_elements = NumElements(input);
+ const int32_t* data = GetTensorData<int32_t>(input);
+ for (int i = 0; i < num_elements; ++i) {
+ if (data[i] < 0) {
+ context->ReportError(context,
+ "POW does not support negative value for int32.");
+ return kTfLiteError;
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (output->type) {
+ case kTfLiteInt32: {
+ // TensorFlow does not support negative exponents for int32.
+ TF_LITE_ENSURE_OK(context, CheckValue(context, input2));
+ PowImpl<int32_t>(input1, input2, output, data->requires_broadcast);
+ break;
+ }
+ case kTfLiteFloat32: {
+ PowImpl<float>(input1, input2, output, data->requires_broadcast);
+ break;
+ }
+ default: {
+ context->ReportError(context, "Unsupported data type: %d", output->type);
+ return kTfLiteError;
+ }
+ }
+ return kTfLiteOk;
+}
+
+} // namespace
+} // namespace pow
+
+TfLiteRegistration* Register_POW() {
+ static TfLiteRegistration r = {pow::Init, pow::Free, pow::Prepare, pow::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
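
[Editor's note: one detail worth calling out in `Eval` above: for `kTfLiteInt32`, `CheckValue` rejects negative exponents before `PowImpl` runs. `std::pow` computes in floating point, so a negative integer exponent yields a fraction that truncates to zero on the cast back to `int32_t`, silently losing the value; TensorFlow errors out on this case and the kernel mirrors that. A small sketch of the behavior being avoided:]

```c++
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // std::pow promotes integer arguments to double.
  double d = std::pow(2, -2);                   // 0.25
  int32_t truncated = static_cast<int32_t>(d);  // 0: the value is silently lost
  std::printf("%g -> %d\n", d, static_cast<int>(truncated));
  return 0;
}
```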
diff --git a/tensorflow/contrib/lite/kernels/pow_test.cc b/tensorflow/contrib/lite/kernels/pow_test.cc
new file mode 100644
index 0000000000..474d323bc3
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/pow_test.cc
@@ -0,0 +1,117 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class PowOpModel : public SingleOpModel {
+ public:
+ PowOpModel(const TensorData& input1, const TensorData& input2,
+ const TensorData& output) {
+ input1_ = AddInput(input1);
+ input2_ = AddInput(input2);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_POW, BuiltinOptions_PowOptions,
+ CreatePowOptions(builder_).Union());
+ BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(PowOpModel, Simple) {
+ PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {}});
+ model.PopulateTensor<int32>(model.input1(), {12, 2, 7, 8});
+ model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 1});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(12, 4, 343, 8));
+}
+
+TEST(PowOpModel, NegativeAndZeroValue) {
+ PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {}});
+ model.PopulateTensor<int32>(model.input1(), {0, 2, -7, 8});
+ model.PopulateTensor<int32>(model.input2(), {1, 2, 3, 0});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(0, 4, -343, 1));
+}
+
+TEST(PowOpModel, Float) {
+ PowOpModel<float> model({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}});
+ model.PopulateTensor<float>(model.input1(), {0.3, 0.4, 0.7, 5.8});
+ model.PopulateTensor<float>(model.input2(), {0.5, 2.7, 3.1, 3.2});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {0.5477226, 0.08424846, 0.33098164, 277.313}, 1e-3)));
+}
+
+TEST(PowOpModel, NegativeFloatTest) {
+ PowOpModel<float> model({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}});
+ model.PopulateTensor<float>(model.input1(), {0.3, 0.4, 0.7, 5.8});
+ model.PopulateTensor<float>(model.input2(), {0.5, -2.7, 3.1, -3.2});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {0.5477226, 11.869653, 0.33098164, 0.003606}, 1e-3)));
+}
+
+TEST(PowOpModel, BroadcastTest) {
+ PowOpModel<int32> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1}}, {TensorType_INT32, {}});
+ model.PopulateTensor<int32>(model.input1(), {12, 2, 7, 8});
+ model.PopulateTensor<int32>(model.input2(), {4});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(20736, 16, 2401, 4096));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index f04fdc67c0..0ca08cd8f3 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -101,6 +101,7 @@ TfLiteRegistration* Register_NOT_EQUAL();
TfLiteRegistration* Register_SQRT();
TfLiteRegistration* Register_RSQRT();
TfLiteRegistration* Register_SHAPE();
+TfLiteRegistration* Register_POW();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -185,6 +186,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_SQRT, Register_SQRT());
AddBuiltin(BuiltinOperator_RSQRT, Register_RSQRT());
AddBuiltin(BuiltinOperator_SHAPE, Register_SHAPE());
+ AddBuiltin(BuiltinOperator_POW, Register_POW());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
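
[Editor's note: with the two registration lines above in place, any model containing a POW node resolves through the stock `BuiltinOpResolver`. A hedged usage sketch; the model path is hypothetical, but the builder/interpreter calls are the standard contrib/lite API of this era.]

```c++
#include <memory>

#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"

int main() {
  // "pow_model.tflite" is a hypothetical flatbuffer containing a POW node.
  auto model = tflite::FlatBufferModel::BuildFromFile("pow_model.tflite");
  if (!model) return 1;

  // BuiltinOpResolver now includes Register_POW(), so no custom
  // registration is needed for models that use POW.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  if (!interpreter || interpreter->AllocateTensors() != kTfLiteOk) return 1;

  // ... populate input tensors here, then:
  interpreter->Invoke();
  return 0;
}
```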
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 04327a44db..793a72272d 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -735,6 +735,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_TILE:
case BuiltinOperator_TOPK_V2:
case BuiltinOperator_TRANSPOSE:
+ case BuiltinOperator_POW:
break;
}
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index ab007993af..748c2f1a04 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -504,6 +504,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_SQRT:
case tflite::BuiltinOperator_RSQRT:
case tflite::BuiltinOperator_SHAPE:
+ case tflite::BuiltinOperator_POW:
FATAL("Op code %d is currently not delegated to NNAPI", builtin);
nn_op_type = -1; // set to invalid
break;
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 5a53ef124d..76ad3ef893 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -158,6 +158,7 @@ enum BuiltinOperator : byte {
SQRT = 75,
RSQRT = 76,
SHAPE = 77,
+ POW = 78,
}
// Options for the builtin operators.
@@ -217,6 +218,7 @@ union BuiltinOptions {
EqualOptions,
NotEqualOptions,
ShapeOptions,
+ PowOptions,
}
enum Padding : byte { SAME, VALID }
@@ -511,6 +513,9 @@ table ShapeOptions {
out_type : TensorType;
}
+table PowOptions {
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 0b8c6387c2..e3ce90aa55 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -196,6 +196,9 @@ struct NotEqualOptionsT;
struct ShapeOptions;
struct ShapeOptionsT;
+struct PowOptions;
+struct PowOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -336,11 +339,12 @@ enum BuiltinOperator {
BuiltinOperator_SQRT = 75,
BuiltinOperator_RSQRT = 76,
BuiltinOperator_SHAPE = 77,
+ BuiltinOperator_POW = 78,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_SHAPE
+ BuiltinOperator_MAX = BuiltinOperator_POW
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[77] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[78] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -418,7 +422,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[77] {
BuiltinOperator_SUM,
BuiltinOperator_SQRT,
BuiltinOperator_RSQRT,
- BuiltinOperator_SHAPE
+ BuiltinOperator_SHAPE,
+ BuiltinOperator_POW
};
return values;
}
@@ -503,6 +508,7 @@ inline const char **EnumNamesBuiltinOperator() {
"SQRT",
"RSQRT",
"SHAPE",
+ "POW",
nullptr
};
return names;
@@ -570,11 +576,12 @@ enum BuiltinOptions {
BuiltinOptions_EqualOptions = 53,
BuiltinOptions_NotEqualOptions = 54,
BuiltinOptions_ShapeOptions = 55,
+ BuiltinOptions_PowOptions = 56,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_ShapeOptions
+ BuiltinOptions_MAX = BuiltinOptions_PowOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[56] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[57] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -631,7 +638,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[56] {
BuiltinOptions_ExpandDimsOptions,
BuiltinOptions_EqualOptions,
BuiltinOptions_NotEqualOptions,
- BuiltinOptions_ShapeOptions
+ BuiltinOptions_ShapeOptions,
+ BuiltinOptions_PowOptions
};
return values;
}
@@ -694,6 +702,7 @@ inline const char **EnumNamesBuiltinOptions() {
"EqualOptions",
"NotEqualOptions",
"ShapeOptions",
+ "PowOptions",
nullptr
};
return names;
@@ -928,6 +937,10 @@ template<> struct BuiltinOptionsTraits<ShapeOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_ShapeOptions;
};
+template<> struct BuiltinOptionsTraits<PowOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_PowOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1399,6 +1412,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_ShapeOptions ?
reinterpret_cast<const ShapeOptionsT *>(value) : nullptr;
}
+ PowOptionsT *AsPowOptions() {
+ return type == BuiltinOptions_PowOptions ?
+ reinterpret_cast<PowOptionsT *>(value) : nullptr;
+ }
+ const PowOptionsT *AsPowOptions() const {
+ return type == BuiltinOptions_PowOptions ?
+ reinterpret_cast<const PowOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -5048,6 +5069,46 @@ inline flatbuffers::Offset<ShapeOptions> CreateShapeOptions(
flatbuffers::Offset<ShapeOptions> CreateShapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ShapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct PowOptionsT : public flatbuffers::NativeTable {
+ typedef PowOptions TableType;
+ PowOptionsT() {
+ }
+};
+
+struct PowOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef PowOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ PowOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(PowOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<PowOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct PowOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit PowOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ PowOptionsBuilder &operator=(const PowOptionsBuilder &);
+ flatbuffers::Offset<PowOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<PowOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<PowOptions> CreatePowOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ PowOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<PowOptions> CreatePowOptions(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -5346,6 +5407,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const ShapeOptions *builtin_options_as_ShapeOptions() const {
return builtin_options_type() == BuiltinOptions_ShapeOptions ? static_cast<const ShapeOptions *>(builtin_options()) : nullptr;
}
+ const PowOptions *builtin_options_as_PowOptions() const {
+ return builtin_options_type() == BuiltinOptions_PowOptions ? static_cast<const PowOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -5597,6 +5661,10 @@ template<> inline const ShapeOptions *Operator::builtin_options_as<ShapeOptions>
return builtin_options_as_ShapeOptions();
}
+template<> inline const PowOptions *Operator::builtin_options_as<PowOptions>() const {
+ return builtin_options_as_PowOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -7576,6 +7644,29 @@ inline flatbuffers::Offset<ShapeOptions> CreateShapeOptions(flatbuffers::FlatBuf
_out_type);
}
+inline PowOptionsT *PowOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new PowOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void PowOptions::UnPackTo(PowOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<PowOptions> PowOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreatePowOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<PowOptions> CreatePowOptions(flatbuffers::FlatBufferBuilder &_fbb, const PowOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PowOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreatePowOptions(
+ _fbb);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -7985,6 +8076,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const ShapeOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_PowOptions: {
+ auto ptr = reinterpret_cast<const PowOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -8223,6 +8318,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const ShapeOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_PowOptions: {
+ auto ptr = reinterpret_cast<const PowOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -8449,6 +8548,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const ShapeOptionsT *>(value);
return CreateShapeOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_PowOptions: {
+ auto ptr = reinterpret_cast<const PowOptionsT *>(value);
+ return CreatePowOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -8675,6 +8778,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new ShapeOptionsT(*reinterpret_cast<ShapeOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_PowOptions: {
+ value = new PowOptionsT(*reinterpret_cast<PowOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -8957,6 +9064,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_PowOptions: {
+ auto ptr = reinterpret_cast<PowOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index d62a311c59..1360f1a273 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -990,6 +990,10 @@ def make_mul_tests(zip_path):
make_binary_op_tests(zip_path, tf.multiply)
+def make_pow_tests(zip_path):
+ make_binary_op_tests(zip_path, tf.pow)
+
+
def make_gather_tests(zip_path):
"""Make a set of tests to do gather."""
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index 6b78f1c05e..48febfb301 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1773,6 +1773,20 @@ void ConvertSparseToDenseOperator(const Model& model,
src_op.validate_indices);
}
+void ConvertPowOperator(const Model& model, const PowOperator& src_op,
+ const char* op_name, GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* pow_op = tensorflow_graph->add_node();
+ pow_op->set_op(op_name);
+ pow_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ for (int i = 0; i < 2; ++i) {
+ *pow_op->add_input() = src_op.inputs[i];
+ }
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*pow_op->mutable_attr())["T"].set_type(data_type);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -1987,6 +2001,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertTileOperator(model,
static_cast<const TensorFlowTileOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kPow) {
+ ConvertPowOperator(model, static_cast<const PowOperator&>(src_op), "Pow",
+ tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 27a1049eaf..00ab7cbaa9 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -175,6 +175,14 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
SetDataTypeForAllOutputs(model, op, data_type);
break;
}
+ case OperatorType::kPow: {
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK(model->GetArray(op->inputs[0]).data_type ==
+ model->GetArray(op->inputs[1]).data_type);
+ const ArrayDataType data_type = model->GetArray(op->inputs[0]).data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
+ break;
+ }
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index c61da203c6..c9c9f13d2e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1611,6 +1611,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kGreaterEqual:
case OperatorType::kEqual:
case OperatorType::kNotEqual:
+ case OperatorType::kPow:
ProcessSimpleBinaryOperator(model, op);
break;
case OperatorType::kAddN:
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 485e853e25..55e39d963f 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1899,6 +1899,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"ParallelDynamicStitch", ConvertDynamicStitchOperator},
{"Placeholder", ConvertPlaceholderOperator},
{"PlaceholderWithDefault", ConvertIdentityOperator},
+ {"Pow", ConvertSimpleOperator<PowOperator, 2>},
{"RandomUniform", ConvertRandomUniform},
{"Range", ConvertRangeOperator},
{"Rank", ConvertSimpleOperator<RankOperator, 1>},
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 89cb061499..aa05a9bd0e 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -138,6 +138,7 @@ enum class OperatorType : uint8 {
kSparseToDense,
kEqual,
kNotEqual,
+ kPow,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -1637,6 +1638,17 @@ struct SparseToDenseOperator : Operator {
bool validate_indices;
};
+// Pow operator:
+//
+// Inputs:
+// Inputs[0]: required: A tensor.
+// Inputs[1]: required: A tensor.
+//
+// TensorFlow equivalent: Pow.
+struct PowOperator : Operator {
+ PowOperator() : Operator(OperatorType::kPow) {}
+};
+
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 2d7a4a7a4c..7e55ae92bd 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1237,6 +1237,7 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
new SimpleOperator<SelectOperator>("SELECT", OperatorType::kSelect));
ops.emplace_back(
new SimpleOperator<SliceOperator>("SLICE", OperatorType::kSlice));
+ ops.emplace_back(new SimpleOperator<PowOperator>("POW", OperatorType::kPow));
// Element-wise operator
ops.emplace_back(new SimpleOperator<SinOperator>("SIN", OperatorType::kSin));
ops.emplace_back(new SimpleOperator<LogOperator>("LOG", OperatorType::kLog));
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 79c8e5d738..8b6808d3c7 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -126,6 +126,7 @@ TEST_F(OperatorTest, SimpleOperators) {
CheckSimpleOperator<LogOperator>("LOG", OperatorType::kLog);
CheckSimpleOperator<TensorFlowSqrtOperator>("SQRT", OperatorType::kSqrt);
CheckSimpleOperator<TensorFlowRsqrtOperator>("RSQRT", OperatorType::kRsqrt);
+ CheckSimpleOperator<PowOperator>("POW", OperatorType::kPow);
}
TEST_F(OperatorTest, BuiltinAdd) {
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 3d9fa732bd..7dc1af9f1d 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -396,6 +396,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(SparseToDense)
HANDLE_OPERATORTYPENAME_CASE(Equal)
HANDLE_OPERATORTYPENAME_CASE(NotEqual)
+ HANDLE_OPERATORTYPENAME_CASE(Pow)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE