aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-07-26 10:53:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-26 10:56:43 -0700
commit6e658c0a5ca77677a954a34fb98f241c592c970d (patch)
treeb645103887539af5232b3f70d80a2eb9b77ed63a /tensorflow/contrib/lite
parent0a3155f7fbf56df5e81c7cbf35afd45173359635 (diff)
Add one_hot op support to TFLite
PiperOrigin-RevId: 206185190
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r--tensorflow/contrib/lite/build_def.bzl1
-rw-r--r--tensorflow/contrib/lite/builtin_op_data.h4
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h1
-rw-r--r--tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md1
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD14
-rw-r--r--tensorflow/contrib/lite/kernels/one_hot.cc199
-rw-r--r--tensorflow/contrib/lite/kernels/one_hot_test.cc182
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc2
-rw-r--r--tensorflow/contrib/lite/model.cc8
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc1
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs6
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h141
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py63
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc17
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc58
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc22
-rw-r--r--tensorflow/contrib/lite/toco/model.h22
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc19
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc8
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc1
21 files changed, 775 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index a8a49784c6..256334576c 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -248,6 +248,7 @@ def generated_test_models():
"mul",
"neg",
"not_equal",
+ "one_hot",
"pack",
"pad",
"padv2",
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index fd16aa1063..70178b2faa 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -282,6 +282,10 @@ typedef struct {
int axis;
} TfLitePackParams;
+typedef struct {
+ int axis;
+} TfLiteOneHotParams;
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 1ae73b9738..0b6568fd2f 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -110,6 +110,7 @@ typedef enum {
kTfLiteBuiltinReduceMax = 82,
kTfLiteBuiltinPack = 83,
kTfLiteBuiltinLogicalOr = 84,
+ kTfLiteBuiltinOneHot = 85,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index 0e8f4339fc..2ea7aeaa5d 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -62,6 +62,7 @@ counterparts:
* [tf.nn.softmax](https://www.tensorflow.org/api_docs/python/tf/nn/softmax) -
*as long as tensors are 2D and axis is the last dimension*
* [tf.nn.top_k](https://www.tensorflow.org/api_docs/python/tf/nn/top_k)
+* [tf.one_hot](https://www.tensorflow.org/api_docs/python/tf/one_hot)
* [tf.pad](https://www.tensorflow.org/api_docs/python/tf/pad) - *as long as
mode and constant_values are not used*
* [tf.reduce_mean](https://www.tensorflow.org/api_docs/python/tf/reduce_mean) -
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index c224132cae..026ab4de03 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -176,6 +176,7 @@ cc_library(
"mfcc.cc",
"mul.cc",
"neg.cc",
+ "one_hot.cc",
"pack.cc",
"pad.cc",
"pooling.cc",
@@ -1171,6 +1172,19 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "one_hot_test",
+ size = "small",
+ srcs = ["one_hot_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/lite/kernels/one_hot.cc b/tensorflow/contrib/lite/kernels/one_hot.cc
new file mode 100644
index 0000000000..9ff3dca932
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/one_hot.cc
@@ -0,0 +1,199 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace one_hot {
+
+constexpr int kIndicesTensor = 0;
+constexpr int kDepthTensor = 1;
+constexpr int kOnValueTensor = 2;
+constexpr int kOffValueTensor = 3;
+constexpr int kOutputTensor = 0;
+
+// Convenience utility for destructuring a node into the appropriate tensors and
+// data for the op. Note that this destructuring is quite cheap, so we can avoid
+// allocating op-specific, persistent data on the heap.
+struct OneHotContext {
+ OneHotContext(TfLiteContext* context, TfLiteNode* node) {
+ indices = GetInput(context, node, kIndicesTensor);
+ depth = GetInput(context, node, kDepthTensor);
+ on_value = GetInput(context, node, kOnValueTensor);
+ off_value = GetInput(context, node, kOffValueTensor);
+ output = GetOutput(context, node, kOutputTensor);
+
+ const auto* params =
+ reinterpret_cast<TfLiteOneHotParams*>(node->builtin_data);
+ const int indices_dims = indices->dims->size;
+ axis = (params->axis == -1) ? indices_dims : params->axis;
+ output_dims = indices_dims + 1;
+ dtype = on_value->type;
+ }
+
+ const TfLiteTensor* indices;
+ const TfLiteTensor* depth;
+ const TfLiteTensor* on_value;
+ const TfLiteTensor* off_value;
+ TfLiteTensor* output;
+ int axis;
+ int output_dims;
+ TfLiteType dtype;
+};
+
+template <typename T, typename TI>
+void OneHotComputeImpl(const OneHotContext& op_context) {
+ // prefix_dim_size == # of elements before the axis
+ // depth == # of elements per axis
+ // suffix_dim_size == # of elements after the axis
+ int prefix_dim_size = 1;
+ for (int i = 0; i < op_context.axis; ++i) {
+ prefix_dim_size *= op_context.indices->dims->data[i];
+ }
+ const int suffix_dim_size = NumElements(op_context.indices) / prefix_dim_size;
+ const int depth = *op_context.depth->data.i32;
+
+ const T on_value = *GetTensorData<T>(op_context.on_value);
+ const T off_value = *GetTensorData<T>(op_context.off_value);
+
+ // View the indices as a matrix of size:
+ // prefix_dim_size x suffix_dim_size
+ // View the output as a matrix of size:
+ // prefix_dim_size x depth x suffix_dim_size
+ // Then the output is:
+ // output(i, j, k) == (indices(i, k) == j) ? on : off
+ T* output = GetTensorData<T>(op_context.output);
+ const TI* indices = GetTensorData<TI>(op_context.indices);
+ for (int i = 0; i < prefix_dim_size; ++i) {
+ for (int j = 0; j < depth; ++j) {
+ for (int k = 0; k < suffix_dim_size; ++k, ++output) {
+ *output = static_cast<int>(indices[i * suffix_dim_size + k]) == j
+ ? on_value
+ : off_value;
+ }
+ }
+ }
+}
+
+template <typename T>
+void OneHotCompute(const OneHotContext& op_context) {
+ if (op_context.indices->type == kTfLiteInt64) {
+ OneHotComputeImpl<T, int64_t>(op_context);
+ } else {
+ OneHotComputeImpl<T, int>(op_context);
+ }
+}
+
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
+ const OneHotContext& op_context) {
+ TF_LITE_ENSURE(context, *op_context.depth->data.i32 >= 0);
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(op_context.output_dims);
+ for (int i = 0; i < op_context.output_dims; ++i) {
+ if (i < op_context.axis) {
+ output_size->data[i] = op_context.indices->dims->data[i];
+ } else if (i == op_context.axis) {
+ output_size->data[i] = *op_context.depth->data.i32;
+ } else {
+ output_size->data[i] = op_context.indices->dims->data[i - 1];
+ }
+ }
+ return context->ResizeTensor(context, op_context.output, output_size);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ OneHotContext op_context{context, node};
+ switch (op_context.dtype) {
+ // TODO(b/111744875): Support uint8 and quantization.
+ case kTfLiteFloat32:
+ case kTfLiteInt16:
+ case kTfLiteInt32:
+ case kTfLiteInt64:
+ case kTfLiteBool:
+ op_context.output->type = op_context.dtype;
+ break;
+ default:
+ context->ReportError(context, "Unknown output data type: %d",
+ op_context.dtype);
+ return kTfLiteError;
+ }
+
+ TF_LITE_ENSURE(context, op_context.indices->type == kTfLiteInt32 ||
+ op_context.indices->type == kTfLiteInt64);
+ TF_LITE_ENSURE(context, op_context.axis >= 0 &&
+ op_context.axis < op_context.output_dims);
+ TF_LITE_ENSURE_EQ(context, NumElements(op_context.depth), 1);
+ TF_LITE_ENSURE_EQ(context, NumElements(op_context.on_value), 1);
+ TF_LITE_ENSURE_EQ(context, NumElements(op_context.off_value), 1);
+ TF_LITE_ENSURE_EQ(context, op_context.on_value->type, op_context.dtype);
+ TF_LITE_ENSURE_EQ(context, op_context.off_value->type, op_context.dtype);
+
+ if (!IsConstantTensor(op_context.depth)) {
+ SetTensorToDynamic(op_context.output);
+ return kTfLiteOk;
+ }
+
+ return ResizeOutputTensor(context, op_context);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ OneHotContext op_context{context, node};
+
+ if (IsDynamicTensor(op_context.output)) {
+ ResizeOutputTensor(context, op_context);
+ }
+
+ switch (op_context.output->type) {
+ case kTfLiteFloat32:
+ OneHotCompute<float>(op_context);
+ break;
+ case kTfLiteInt32:
+ OneHotCompute<int>(op_context);
+ break;
+ case kTfLiteInt64:
+ OneHotCompute<int64_t>(op_context);
+ break;
+ case kTfLiteBool:
+ OneHotCompute<bool>(op_context);
+ break;
+ default:
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace one_hot
+
+TfLiteRegistration* Register_ONE_HOT() {
+ static TfLiteRegistration r = {
+ nullptr,
+ nullptr,
+ one_hot::Prepare,
+ one_hot::Eval,
+ };
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/one_hot_test.cc b/tensorflow/contrib/lite/kernels/one_hot_test.cc
new file mode 100644
index 0000000000..6b604ec7a7
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/one_hot_test.cc
@@ -0,0 +1,182 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <initializer_list>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+template <typename T>
+class OneHotOpModel : public SingleOpModel {
+ public:
+ OneHotOpModel(std::initializer_list<int> input_shape, int depth_value,
+ TensorType dtype, int axis = -1, T on_value = 1,
+ T off_value = 0, TensorType indices_type = TensorType_INT32) {
+ indices_ = AddInput(indices_type);
+ int depth = AddInput(TensorType_INT32);
+ int on = AddInput(dtype);
+ int off = AddInput(dtype);
+ output_ = AddOutput(dtype);
+ SetBuiltinOp(BuiltinOperator_ONE_HOT, BuiltinOptions_OneHotOptions,
+ CreateOneHotOptions(builder_, axis).Union());
+ BuildInterpreter({input_shape});
+
+ PopulateTensor<int>(depth, {depth_value});
+ PopulateTensor<T>(on, {on_value});
+ PopulateTensor<T>(off, {off_value});
+ }
+
+ template <typename TI>
+ void SetIndices(std::initializer_list<TI> data) {
+ PopulateTensor<TI>(indices_, data);
+ }
+
+ TfLiteStatus InvokeWithResult() { return interpreter_->Invoke(); }
+
+ int32_t GetOutputSize() { return GetTensorSize(output_); }
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int indices_;
+ int output_;
+};
+
+TEST(OneHotOpTest, BasicFloat) {
+ const int depth = 3;
+ OneHotOpModel<float> model({3}, depth, TensorType_FLOAT32);
+ model.SetIndices({0, 1, 2});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f}));
+}
+
+TEST(OneHotOpTest, BasicInt) {
+ const int depth = 3;
+ OneHotOpModel<int> model({3}, depth, TensorType_INT32);
+ model.SetIndices({0, 1, 2});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1}));
+}
+
+TEST(OneHotOpTest, BasicBool) {
+ const int depth = 3;
+ OneHotOpModel<bool> model({3}, depth, TensorType_BOOL);
+ model.SetIndices({0, 1, 2});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({true, false, false, false, true, false, false,
+ false, true}));
+}
+
+TEST(OneHotOpTest, SmallDepth) {
+ const int depth = 1;
+ OneHotOpModel<int> model({3}, depth, TensorType_INT32);
+ model.SetIndices({0, 1, 2});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 1}));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0}));
+}
+
+TEST(OneHotOpTest, BigDepth) {
+ const int depth = 4;
+ OneHotOpModel<int> model({2}, depth, TensorType_INT32);
+ model.SetIndices({0, 1});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 4}));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 0, 1, 0, 0}));
+}
+
+TEST(OneHotOpTest, OnOffValues) {
+ const int depth = 3;
+ const int axis = -1;
+ const int on = 5;
+ const int off = 0;
+ OneHotOpModel<int> model({4}, depth, TensorType_INT32, axis, on, off);
+ model.SetIndices({0, 2, -1, 1});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({4, 3}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({5, 0, 0, 0, 0, 5, 0, 0, 0, 0, 5, 0}));
+}
+
+TEST(OneHotOpTest, ZeroAxis) {
+ const int depth = 3;
+ const int axis = 0;
+ const int on = 5;
+ const int off = 0;
+ OneHotOpModel<int> model({4}, depth, TensorType_INT32, axis, on, off);
+ model.SetIndices({0, 2, -1, 1});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 4}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({5, 0, 0, 0, 0, 0, 0, 5, 0, 5, 0, 0}));
+}
+
+TEST(OneHotOpTest, MultiDimensionalIndices) {
+ const int depth = 3;
+ const int axis = -1;
+ const float on = 2;
+ const float off = 0;
+ OneHotOpModel<float> model({2, 2}, depth, TensorType_FLOAT32, axis, on, off);
+ model.SetIndices({0, 2, 1, -1});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 3}));
+ EXPECT_THAT(model.GetOutput(),
+ ElementsAreArray({2, 0, 0, 0, 0, 2, 0, 2, 0, 0, 0, 0}));
+}
+
+TEST(OneHotOpTest, Int64Indices) {
+ const int depth = 3;
+ const int axis = -1;
+ const int on = 1;
+ const int off = 0;
+ OneHotOpModel<int> model({3}, depth, TensorType_INT32, axis, on, off,
+ TensorType_INT64);
+ std::initializer_list<int64_t> indices = {0, 1, 2};
+ model.SetIndices(indices);
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({3, 3}));
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 0, 0, 0, 1, 0, 0, 0, 1}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 0b70bed308..da69b85041 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -107,6 +107,7 @@ TfLiteRegistration* Register_SHAPE();
TfLiteRegistration* Register_POW();
TfLiteRegistration* Register_FAKE_QUANT();
TfLiteRegistration* Register_PACK();
+TfLiteRegistration* Register_ONE_HOT();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -197,6 +198,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_POW, Register_POW());
AddBuiltin(BuiltinOperator_FAKE_QUANT, Register_FAKE_QUANT(), 1, 2);
AddBuiltin(BuiltinOperator_PACK, Register_PACK());
+ AddBuiltin(BuiltinOperator_ONE_HOT, Register_ONE_HOT());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index c6869feb16..5814cddc5b 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -730,6 +730,14 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = static_cast<void*>(params);
break;
}
+ case BuiltinOperator_ONE_HOT: {
+ auto* params = MallocPOD<TfLiteOneHotParams>();
+ if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
+ params->axis = schema_params->axis();
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
// Below are the ops with no builtin_data strcture.
case BuiltinOperator_BATCH_TO_SPACE_ND:
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 551e8ed320..1c06b29deb 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -623,6 +623,7 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_FAKE_QUANT:
case tflite::BuiltinOperator_PACK:
case tflite::BuiltinOperator_LOGICAL_OR:
+ case tflite::BuiltinOperator_ONE_HOT:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index a285bf9919..8ed98ddaf4 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -166,6 +166,7 @@ enum BuiltinOperator : byte {
REDUCE_MAX = 82,
PACK = 83,
LOGICAL_OR = 84,
+ ONE_HOT = 85,
}
// Options for the builtin operators.
@@ -230,6 +231,7 @@ union BuiltinOptions {
FakeQuantOptions,
PackOptions,
LogicalOrOptions,
+ OneHotOptions,
}
enum Padding : byte { SAME, VALID }
@@ -549,6 +551,10 @@ table PackOptions {
table LogicalOrOptions {
}
+table OneHotOptions {
+ axis:int;
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 8c1d6d6a36..4402f89b85 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -211,6 +211,9 @@ struct PackOptionsT;
struct LogicalOrOptions;
struct LogicalOrOptionsT;
+struct OneHotOptions;
+struct OneHotOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -361,11 +364,12 @@ enum BuiltinOperator {
BuiltinOperator_REDUCE_MAX = 82,
BuiltinOperator_PACK = 83,
BuiltinOperator_LOGICAL_OR = 84,
+ BuiltinOperator_ONE_HOT = 85,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_LOGICAL_OR
+ BuiltinOperator_MAX = BuiltinOperator_ONE_HOT
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[84] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[85] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -450,7 +454,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[84] {
BuiltinOperator_REDUCE_PROD,
BuiltinOperator_REDUCE_MAX,
BuiltinOperator_PACK,
- BuiltinOperator_LOGICAL_OR
+ BuiltinOperator_LOGICAL_OR,
+ BuiltinOperator_ONE_HOT
};
return values;
}
@@ -542,6 +547,7 @@ inline const char **EnumNamesBuiltinOperator() {
"REDUCE_MAX",
"PACK",
"LOGICAL_OR",
+ "ONE_HOT",
nullptr
};
return names;
@@ -614,11 +620,12 @@ enum BuiltinOptions {
BuiltinOptions_FakeQuantOptions = 58,
BuiltinOptions_PackOptions = 59,
BuiltinOptions_LogicalOrOptions = 60,
+ BuiltinOptions_OneHotOptions = 61,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_LogicalOrOptions
+ BuiltinOptions_MAX = BuiltinOptions_OneHotOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[61] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[62] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -680,7 +687,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[61] {
BuiltinOptions_ArgMinOptions,
BuiltinOptions_FakeQuantOptions,
BuiltinOptions_PackOptions,
- BuiltinOptions_LogicalOrOptions
+ BuiltinOptions_LogicalOrOptions,
+ BuiltinOptions_OneHotOptions
};
return values;
}
@@ -748,6 +756,7 @@ inline const char **EnumNamesBuiltinOptions() {
"FakeQuantOptions",
"PackOptions",
"LogicalOrOptions",
+ "OneHotOptions",
nullptr
};
return names;
@@ -1002,6 +1011,10 @@ template<> struct BuiltinOptionsTraits<LogicalOrOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_LogicalOrOptions;
};
+template<> struct BuiltinOptionsTraits<OneHotOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_OneHotOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1513,6 +1526,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_LogicalOrOptions ?
reinterpret_cast<const LogicalOrOptionsT *>(value) : nullptr;
}
+ OneHotOptionsT *AsOneHotOptions() {
+ return type == BuiltinOptions_OneHotOptions ?
+ reinterpret_cast<OneHotOptionsT *>(value) : nullptr;
+ }
+ const OneHotOptionsT *AsOneHotOptions() const {
+ return type == BuiltinOptions_OneHotOptions ?
+ reinterpret_cast<const OneHotOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -5452,6 +5473,60 @@ inline flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(
flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogicalOrOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct OneHotOptionsT : public flatbuffers::NativeTable {
+ typedef OneHotOptions TableType;
+ int32_t axis;
+ OneHotOptionsT()
+ : axis(0) {
+ }
+};
+
+struct OneHotOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef OneHotOptionsT NativeTableType;
+ enum {
+ VT_AXIS = 4
+ };
+ int32_t axis() const {
+ return GetField<int32_t>(VT_AXIS, 0);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int32_t>(verifier, VT_AXIS) &&
+ verifier.EndTable();
+ }
+ OneHotOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<OneHotOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct OneHotOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_axis(int32_t axis) {
+ fbb_.AddElement<int32_t>(OneHotOptions::VT_AXIS, axis, 0);
+ }
+ explicit OneHotOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ OneHotOptionsBuilder &operator=(const OneHotOptionsBuilder &);
+ flatbuffers::Offset<OneHotOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<OneHotOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ int32_t axis = 0) {
+ OneHotOptionsBuilder builder_(_fbb);
+ builder_.add_axis(axis);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -5765,6 +5840,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const LogicalOrOptions *builtin_options_as_LogicalOrOptions() const {
return builtin_options_type() == BuiltinOptions_LogicalOrOptions ? static_cast<const LogicalOrOptions *>(builtin_options()) : nullptr;
}
+ const OneHotOptions *builtin_options_as_OneHotOptions() const {
+ return builtin_options_type() == BuiltinOptions_OneHotOptions ? static_cast<const OneHotOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -6036,6 +6114,10 @@ template<> inline const LogicalOrOptions *Operator::builtin_options_as<LogicalOr
return builtin_options_as_LogicalOrOptions();
}
+template<> inline const OneHotOptions *Operator::builtin_options_as<OneHotOptions>() const {
+ return builtin_options_as_OneHotOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -8151,6 +8233,32 @@ inline flatbuffers::Offset<LogicalOrOptions> CreateLogicalOrOptions(flatbuffers:
_fbb);
}
+inline OneHotOptionsT *OneHotOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new OneHotOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void OneHotOptions::UnPackTo(OneHotOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = axis(); _o->axis = _e; };
+}
+
+inline flatbuffers::Offset<OneHotOptions> OneHotOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateOneHotOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<OneHotOptions> CreateOneHotOptions(flatbuffers::FlatBufferBuilder &_fbb, const OneHotOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OneHotOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _axis = _o->axis;
+ return tflite::CreateOneHotOptions(
+ _fbb,
+ _axis);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -8580,6 +8688,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const LogicalOrOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_OneHotOptions: {
+ auto ptr = reinterpret_cast<const OneHotOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -8838,6 +8950,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const LogicalOrOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_OneHotOptions: {
+ auto ptr = reinterpret_cast<const OneHotOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -9084,6 +9200,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const LogicalOrOptionsT *>(value);
return CreateLogicalOrOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_OneHotOptions: {
+ auto ptr = reinterpret_cast<const OneHotOptionsT *>(value);
+ return CreateOneHotOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -9330,6 +9450,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new LogicalOrOptionsT(*reinterpret_cast<LogicalOrOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_OneHotOptions: {
+ value = new OneHotOptionsT(*reinterpret_cast<OneHotOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -9637,6 +9761,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_OneHotOptions: {
+ auto ptr = reinterpret_cast<OneHotOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 41ece94237..3c7ad9d8b3 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -242,7 +242,9 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100):
value = (max_value-min_value)*np.random.random_sample(shape)+min_value
elif dtype in (tf.int32, tf.uint8, tf.int64):
value = np.random.randint(min_value, max_value+1, shape)
- return value.astype(dtype)
+
+ return np.dtype(dtype).type(value) if np.isscalar(value) else value.astype(
+ dtype)
def create_scalar_data(dtype, min_value=-100, max_value=100):
@@ -1665,6 +1667,65 @@ def make_shape_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_one_hot_tests(zip_path):
+ """Make a set of tests to do one_hot."""
+
+ test_parameters = [{
+ "indices_type": [tf.int32, tf.int64],
+ "indices_shape": [[3], [4, 4], [1, 5], [5, 1]],
+ "axis": [0, 1],
+ "dtype": [tf.int32, tf.int64, tf.float32],
+ "provide_optional_inputs": [True, False],
+ }]
+
+ def build_graph(parameters):
+ indices = tf.placeholder(
+ dtype=parameters["indices_type"],
+ name="indices",
+ shape=parameters["indices_shape"])
+ depth = tf.placeholder(dtype=tf.int32, name="depth", shape=())
+
+ if not parameters["provide_optional_inputs"]:
+ out = tf.one_hot(indices=indices, depth=depth)
+ return [indices, depth], [out]
+
+ on_value = tf.placeholder(
+ dtype=parameters["dtype"], name="on_value", shape=())
+ off_value = tf.placeholder(
+ dtype=parameters["dtype"], name="off_value", shape=())
+ out = tf.one_hot(
+ indices=indices,
+ depth=depth,
+ on_value=on_value,
+ off_value=off_value,
+ axis=parameters["axis"],
+ dtype=parameters["dtype"])
+ return [indices, depth, on_value, off_value], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = [
+ create_tensor_data(
+ parameters["indices_type"],
+ shape=parameters["indices_shape"],
+ min_value=-1,
+ max_value=10),
+ create_tensor_data(tf.int32, shape=None, min_value=1, max_value=10),
+ ]
+
+ if parameters["provide_optional_inputs"]:
+ input_values.append(
+ create_tensor_data(
+ parameters["dtype"], shape=None, min_value=1, max_value=10))
+ input_values.append(
+ create_tensor_data(
+ parameters["dtype"], shape=None, min_value=-1, max_value=0))
+
+ return input_values, sess.run(
+ outputs, feed_dict=dict(zip(inputs, input_values)))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_resize_bilinear_tests(zip_path):
"""Make a set of tests to do resize_bilinear."""
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index b79bb300f0..9983e59910 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1316,6 +1316,20 @@ void ConvertResizeBilinearOperator(const Model& model,
(*resize_op->mutable_attr())["align_corners"].set_b(src_op.align_corners);
}
+void ConvertOneHotOperator(const Model& model, const OneHotOperator& src_op,
+ GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* onehot_op = tensorflow_graph->add_node();
+ onehot_op->set_op("OneHot");
+ onehot_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 4);
+ for (const auto& input : src_op.inputs) {
+ *onehot_op->add_input() = input;
+ }
+ (*onehot_op->mutable_attr())["T"].set_type(
+ GetTensorFlowDataType(model, src_op.outputs[0]));
+ (*onehot_op->mutable_attr())["axis"].set_i(src_op.axis);
+}
+
namespace {
// TODO(aselle): Remove when available in absl
absl::string_view FindLongestCommonPrefix(absl::string_view a,
@@ -2158,6 +2172,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertLogicalNotOperator(model,
static_cast<const LogicalNotOperator&>(src_op),
tensorflow_graph);
+ } else if (src_op.type == OperatorType::kOneHot) {
+ ConvertOneHotOperator(model, static_cast<const OneHotOperator&>(src_op),
+ tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index 9c22497d5e..0f94006f34 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -201,6 +201,18 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
SetDataTypeForAllOutputs(model, op, data_type);
break;
}
+ case OperatorType::kOneHot: {
+ CHECK_EQ(op->inputs.size(), 4);
+ CHECK_EQ(op->outputs.size(), 1);
+ const ArrayDataType on_value_type =
+ model->GetArray(op->inputs[OneHotOperator::ON_VALUE_INPUT]).data_type;
+ const ArrayDataType off_value_type =
+ model->GetArray(op->inputs[OneHotOperator::OFF_VALUE_INPUT])
+ .data_type;
+ CHECK(on_value_type == off_value_type);
+ model->GetArray(op->outputs[0]).data_type = on_value_type;
+ break;
+ }
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index a03b589bae..5aa0fddf57 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1578,6 +1578,61 @@ void ProcessAnyOperator(Model* model, AnyOperator* op) {
}
}
+void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
+ CHECK_EQ(op->inputs.size(), 4);
+ CHECK_EQ(op->outputs.size(), 1);
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.has_shape()) {
+ // Shape already propagated
+ return;
+ }
+
+ // Yield until indices dims have been resolved.
+ const auto& indices_array =
+ model->GetArray(op->inputs[OneHotOperator::INDICES_INPUT]);
+ if (!indices_array.has_shape()) {
+ return;
+ }
+
+ // Yield until depth is constant and dims have been resolved.
+ if (!IsConstantParameterArray(*model,
+ op->inputs[OneHotOperator::DEPTH_INPUT])) {
+ return;
+ }
+ const auto& depth_array =
+ model->GetArray(op->inputs[OneHotOperator::DEPTH_INPUT]);
+ if (!depth_array.has_shape()) {
+ return;
+ }
+
+ CHECK(depth_array.data_type == ArrayDataType::kInt32)
+ << "Depth array must be int32.";
+ CHECK_EQ(RequiredBufferSizeForShape(depth_array.shape()), 1)
+ << "Depth array must be scalar.";
+
+ const int depth = depth_array.GetBuffer<ArrayDataType::kInt32>().data[0];
+ CHECK_GE(depth, 0) << "Depth must be non-negative.";
+
+ const int indices_dims = indices_array.shape().dimensions_count();
+ const int output_dims = indices_dims + 1;
+ const int axis = op->axis == -1 ? indices_dims : op->axis;
+ CHECK_GE(axis, 0) << "Resolved axis must be non-negative.";
+
+ auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
+ mutable_dims->resize(output_dims);
+ for (int i = 0; i < output_dims; ++i) {
+ int dim = 0;
+ if (i < axis) {
+ dim = indices_array.shape().dims(i);
+ } else if (i == axis) {
+ dim = depth;
+ } else {
+ dim = indices_array.shape().dims(i - 1);
+ }
+ (*mutable_dims)[i] = dim;
+ }
+}
+
} // namespace
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
@@ -1825,6 +1880,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kAny:
ProcessAnyOperator(model, static_cast<AnyOperator*>(op));
break;
+ case OperatorType::kOneHot:
+ ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op));
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index f36f720857..f92f33497d 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1833,6 +1833,27 @@ tensorflow::Status ConvertSparseToDenseOperator(
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertOneHotOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "OneHot");
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
+
+ const auto dtype = GetDataTypeAttr(node, "T");
+ // TODO(b/111744875): Support DT_UINT8 and quantization.
+ CHECK(dtype == DT_INT32 || dtype == DT_INT64 || dtype == DT_FLOAT ||
+ dtype == DT_BOOL);
+
+ auto op = absl::make_unique<OneHotOperator>();
+ op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : -1;
+ for (const string& input : node.input()) {
+ op->inputs.push_back(input);
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op.release());
+ return tensorflow::Status::OK();
+}
+
} // namespace
namespace internal {
@@ -1909,6 +1930,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
{"NoOp", ConvertNoOpOperator},
{"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>},
+ {"OneHot", ConvertOneHotOperator},
{"Pack", ConvertPackOperator},
{"Pad", ConvertSimpleOperator<PadOperator, 2>},
{"PadV2", ConvertSimpleOperator<PadV2Operator, 3>},
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 6459dccf64..a3827977fd 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -64,6 +64,7 @@ enum class OperatorType : uint8 {
kMaxPool,
kFakeQuant,
kMul,
+ kOneHot,
kRandomUniform,
kRange,
kRank,
@@ -1768,6 +1769,27 @@ struct LogicalNotOperator : Operator {
LogicalNotOperator() : Operator(OperatorType::kLogicalNot) {}
};
+// OneHot operator:
+//
+// Inputs:
+// Inputs[0]: required: indices.
+// Inputs[1]: required: depth.
+// Inputs[2]: required: on_value.
+// Inputs[3]: required: off_value.
+//
+// TensorFlow equivalent: OneHot.
+struct OneHotOperator : Operator {
+ enum Inputs {
+ INDICES_INPUT = 0,
+ DEPTH_INPUT = 1,
+ ON_VALUE_INPUT = 2,
+ OFF_VALUE_INPUT = 3,
+ };
+
+ OneHotOperator() : Operator(OperatorType::kOneHot) {}
+ int axis = -1;
+};
+
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 4b2ef756cc..769e350ea9 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1053,6 +1053,23 @@ class Shape
int GetVersion(const Operator& op) const override { return 1; }
};
+class OneHot : public BuiltinOperator<OneHotOperator, ::tflite::OneHotOptions,
+ ::tflite::BuiltinOptions_OneHotOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateOneHotOptions(*builder, op.axis);
+ }
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->axis = options.axis();
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
class TensorFlowUnsupported : public BaseOperator {
public:
using BaseOperator::BaseOperator;
@@ -1278,6 +1295,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
OperatorType::kFakeQuant));
ops.emplace_back(
new Pack(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
+ ops.emplace_back(
+ new OneHot(::tflite::BuiltinOperator_ONE_HOT, OperatorType::kOneHot));
// Custom Operators.
ops.emplace_back(
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index 44de6fbf64..7e1e32ae54 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -462,6 +462,14 @@ TEST_F(OperatorTest, BuiltinPack) {
EXPECT_EQ(op.axis, output_toco_op->axis);
}
+TEST_F(OperatorTest, BuiltinOneHot) {
+ OneHotOperator op;
+ op.axis = 2;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("ONE_HOT", OperatorType::kOneHot), op);
+ EXPECT_EQ(op.axis, output_toco_op->axis);
+}
+
TEST_F(OperatorTest, TensorFlowUnsupported) {
TensorFlowUnsupportedOperator op;
op.tensorflow_op = "MyCustomUnsupportedOp";
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 98e416b76e..93c30dd0f8 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -356,6 +356,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(ReduceMin) // Reduction Min
HANDLE_OPERATORTYPENAME_CASE(Minimum) // Element-wise Minimum
HANDLE_OPERATORTYPENAME_CASE(Neg)
+ HANDLE_OPERATORTYPENAME_CASE(OneHot)
HANDLE_OPERATORTYPENAME_CASE(Pack)
HANDLE_OPERATORTYPENAME_CASE(Pad)
HANDLE_OPERATORTYPENAME_CASE(PadV2)