aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/import_tensorflow.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-31 06:05:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-31 06:07:59 -0700
commit7e2e57410eb40c0512dc573955fd256a6c787741 (patch)
treeec345a16ed486ec5a964ac5d6be20bde7d7b401c /tensorflow/contrib/lite/toco/import_tensorflow.cc
parentca4bda919793cc2578e5c0f7440525261da16fdf (diff)
implementation of sparse_to_dense
PiperOrigin-RevId: 198710452
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 27e9d1af88..94ec7c24d4 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -2133,6 +2133,24 @@ void ConvertDynamicStitchOperator(const NodeDef& node,
model->operators.emplace_back(op.release());
}
+void ConvertSparseToDenseOperator(const NodeDef& node,
+ const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "SparseToDense");
+ CheckInputsCount(node, tf_import_flags, 4);
+
+ auto* op = new SparseToDenseOperator;
+ for (const string& input : node.input()) {
+ op->inputs.push_back(input);
+ }
+ op->outputs.push_back(node.name());
+
+ op->validate_indices = HasAttr(node, "validate_indices")
+ ? GetBoolAttr(node, "validate_indices")
+ : true;
+ model->operators.emplace_back(op);
+}
+
} // namespace
namespace internal {
@@ -2314,6 +2332,8 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node,
ConvertSinOperator(node, tf_import_flags, model);
} else if (node.op() == "Select") {
ConvertSelectOperator(node, tf_import_flags, model);
+ } else if (node.op() == "SparseToDense") {
+ ConvertSparseToDenseOperator(node, tf_import_flags, model);
} else {
ConvertUnsupportedOperator(node, tf_import_flags, model);
}