aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/import_tensorflow.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-18 14:12:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-18 14:15:24 -0700
commit4ca04537c0d1d75ea37944aa3bb2dc749428031a (patch)
tree663b7a6b36e560258dd95a603e0440a4c7b278ac /tensorflow/contrib/lite/toco/import_tensorflow.cc
parent44af531d952a35c887770ecc4cfddfb0431c2478 (diff)
Import/export support for Any, LogicalAnd, and LogicalNot ops.
PiperOrigin-RevId: 205134621
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc39
1 files changed, 22 insertions, 17 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 9dde7a8bd6..8bb797fe0f 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1042,22 +1042,6 @@ tensorflow::Status ConvertSimpleOperator(
return ConvertSimpleOperator<Op>(node, tf_import_flags, model);
}
-tensorflow::Status ConvertMinOperator(
- const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Min");
- TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
- auto* op = new TensorFlowMinOperator;
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- model->operators.emplace_back(op);
- if (HasAttr(node, "keep_dims")) {
- op->keep_dims = GetBoolAttr(node, "keep_dims");
- }
- return tensorflow::Status::OK();
-}
-
tensorflow::Status ConvertUnsupportedOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -1594,6 +1578,24 @@ tensorflow::Status ConvertShapeOperator(
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertAnyOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "Any");
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
+ const auto idx_type =
+ HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
+ CHECK(idx_type == DT_INT32);
+ auto op = absl::make_unique<AnyOperator>();
+ op->inputs.push_back(node.input(0));
+ op->inputs.push_back(node.input(1));
+ op->outputs.push_back(node.name());
+ op->keep_dims =
+ HasAttr(node, "keep_dims") ? GetBoolAttr(node, "keep_dims") : false;
+ model->operators.push_back(std::move(op));
+ return tensorflow::Status::OK();
+}
+
void StripCaretFromArrayNames(Model* model) {
for (auto& op : model->operators) {
for (auto& input : op->inputs) {
@@ -1829,6 +1831,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Add", ConvertSimpleOperator<AddOperator, 2>},
{"AddN", ConvertSimpleOperator<AddNOperator>},
{"All", ConvertSimpleOperator<TensorFlowAllOperator>},
+ {"Any", ConvertAnyOperator},
{"ArgMax", ConvertArgMinMaxOperator<ArgMaxOperator, kArgMax>},
{"ArgMin", ConvertArgMinMaxOperator<ArgMinOperator, kArgMin>},
{"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>},
@@ -1872,13 +1875,15 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>},
{"Log", ConvertSimpleOperator<LogOperator, 1>},
{"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>},
+ {"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2>},
+ {"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1>},
{"MatMul", ConvertMatMulOperator},
{"Max", ConvertReduceOperator<TensorFlowMaxOperator>},
{"MaxPool", ConvertMaxPoolOperator},
{"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2>},
{"Mean", ConvertReduceOperator<MeanOperator>},
{"Merge", ConvertSimpleOperator<TensorFlowMergeOperator, 2>},
- {"Min", ConvertMinOperator},
+ {"Min", ConvertReduceOperator<TensorFlowMinOperator>},
{"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2>},
{"Mul", ConvertSimpleOperator<MulOperator, 2>},
{"Neg", ConvertSimpleOperator<NegOperator, 1>},