aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/import_tensorflow.cc
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-07-26 10:53:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-26 10:56:43 -0700
commit6e658c0a5ca77677a954a34fb98f241c592c970d (patch)
treeb645103887539af5232b3f70d80a2eb9b77ed63a /tensorflow/contrib/lite/toco/import_tensorflow.cc
parent0a3155f7fbf56df5e81c7cbf35afd45173359635 (diff)
Add one_hot op support to TFLite
PiperOrigin-RevId: 206185190
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc22
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index f36f720857..f92f33497d 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1833,6 +1833,27 @@ tensorflow::Status ConvertSparseToDenseOperator(
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertOneHotOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "OneHot");
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 4));
+
+ const auto dtype = GetDataTypeAttr(node, "T");
+ // TODO(b/111744875): Support DT_UINT8 and quantization.
+ CHECK(dtype == DT_INT32 || dtype == DT_INT64 || dtype == DT_FLOAT ||
+ dtype == DT_BOOL);
+
+ auto op = absl::make_unique<OneHotOperator>();
+ op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : -1;
+ for (const string& input : node.input()) {
+ op->inputs.push_back(input);
+ }
+ op->outputs.push_back(node.name());
+ model->operators.emplace_back(op.release());
+ return tensorflow::Status::OK();
+}
+
} // namespace
namespace internal {
@@ -1909,6 +1930,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"NextIteration", ConvertOperatorSpecialCasedAsRNNBackEdge},
{"NoOp", ConvertNoOpOperator},
{"NotEqual", ConvertSimpleOperator<TensorFlowNotEqualOperator, 2>},
+ {"OneHot", ConvertOneHotOperator},
{"Pack", ConvertPackOperator},
{"Pad", ConvertSimpleOperator<PadOperator, 2>},
{"PadV2", ConvertSimpleOperator<PadV2Operator, 3>},