diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-18 14:12:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-18 14:15:24 -0700 |
commit | 4ca04537c0d1d75ea37944aa3bb2dc749428031a (patch) | |
tree | 663b7a6b36e560258dd95a603e0440a4c7b278ac /tensorflow/contrib/lite/toco/import_tensorflow.cc | |
parent | 44af531d952a35c887770ecc4cfddfb0431c2478 (diff) |
Import/export support for Any, LogicalAnd, and LogicalNot ops.
PiperOrigin-RevId: 205134621
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 39 |
1 files changed, 22 insertions, 17 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 9dde7a8bd6..8bb797fe0f 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1042,22 +1042,6 @@ tensorflow::Status ConvertSimpleOperator( return ConvertSimpleOperator<Op>(node, tf_import_flags, model); } -tensorflow::Status ConvertMinOperator( - const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, - Model* model) { - CHECK_EQ(node.op(), "Min"); - TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); - auto* op = new TensorFlowMinOperator; - op->inputs.push_back(node.input(0)); - op->inputs.push_back(node.input(1)); - op->outputs.push_back(node.name()); - model->operators.emplace_back(op); - if (HasAttr(node, "keep_dims")) { - op->keep_dims = GetBoolAttr(node, "keep_dims"); - } - return tensorflow::Status::OK(); -} - tensorflow::Status ConvertUnsupportedOperator( const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1594,6 +1578,24 @@ tensorflow::Status ConvertShapeOperator( return tensorflow::Status::OK(); } +tensorflow::Status ConvertAnyOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Any"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2)); + const auto idx_type = + HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32; + CHECK(idx_type == DT_INT32); + auto op = absl::make_unique<AnyOperator>(); + op->inputs.push_back(node.input(0)); + op->inputs.push_back(node.input(1)); + op->outputs.push_back(node.name()); + op->keep_dims = + HasAttr(node, "keep_dims") ? GetBoolAttr(node, "keep_dims") : false; + model->operators.push_back(std::move(op)); + return tensorflow::Status::OK(); +} + void StripCaretFromArrayNames(Model* model) { for (auto& op : model->operators) { for (auto& input : op->inputs) { @@ -1829,6 +1831,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"Add", ConvertSimpleOperator<AddOperator, 2>}, {"AddN", ConvertSimpleOperator<AddNOperator>}, {"All", ConvertSimpleOperator<TensorFlowAllOperator>}, + {"Any", ConvertAnyOperator}, {"ArgMax", ConvertArgMinMaxOperator<ArgMaxOperator, kArgMax>}, {"ArgMin", ConvertArgMinMaxOperator<ArgMinOperator, kArgMin>}, {"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>}, @@ -1872,13 +1875,15 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"LessEqual", ConvertSimpleOperator<TensorFlowLessEqualOperator, 2>}, {"Log", ConvertSimpleOperator<LogOperator, 1>}, {"LogSoftmax", ConvertSimpleOperator<LogSoftmaxOperator, 1>}, + {"LogicalAnd", ConvertSimpleOperator<LogicalAndOperator, 2>}, + {"LogicalNot", ConvertSimpleOperator<LogicalNotOperator, 1>}, {"MatMul", ConvertMatMulOperator}, {"Max", ConvertReduceOperator<TensorFlowMaxOperator>}, {"MaxPool", ConvertMaxPoolOperator}, {"Maximum", ConvertSimpleOperator<TensorFlowMaximumOperator, 2>}, {"Mean", ConvertReduceOperator<MeanOperator>}, {"Merge", ConvertSimpleOperator<TensorFlowMergeOperator, 2>}, - {"Min", ConvertMinOperator}, + {"Min", ConvertReduceOperator<TensorFlowMinOperator>}, {"Minimum", ConvertSimpleOperator<TensorFlowMinimumOperator, 2>}, {"Mul", ConvertSimpleOperator<MulOperator, 2>}, {"Neg", ConvertSimpleOperator<NegOperator, 1>}, |