diff options
author | Jared Duke <jdduke@google.com> | 2018-06-20 11:48:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-20 11:51:26 -0700 |
commit | 4efefb90391b12c95339ed3b46a02b62ea5e195d (patch) | |
tree | bb3f9bb986b89287983ea8e7c35827993aad7206 /tensorflow/contrib/lite/toco/import_tensorflow.cc | |
parent | e51df5918020cdfada26022240091e5529f7da60 (diff) |
Implement TFLite Shape operator
PiperOrigin-RevId: 201389618
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/import_tensorflow.cc | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index caca199d2e..8da33e8a22 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -1573,6 +1573,22 @@ tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge( return tensorflow::Status::OK(); } +tensorflow::Status ConvertShapeOperator( + const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Shape"); + TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1)); + const auto out_type = + HasAttr(node, "out_type") ? GetDataTypeAttr(node, "out_type") : DT_INT32; + CHECK(out_type == DT_INT64 || out_type == DT_INT32); + auto op = absl::make_unique<TensorFlowShapeOperator>(); + op->output_data_type = ConvertDataType(out_type); + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.push_back(std::move(op)); + return tensorflow::Status::OK(); +} + void StripCaretFromArrayNames(Model* model) { for (auto& op : model->operators) { for (auto& input : op->inputs) { @@ -1877,7 +1893,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() { {"ResizeBilinear", ConvertResizeBilinearOperator}, {"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1>}, {"Select", ConvertSimpleOperator<SelectOperator, 3>}, - {"Shape", ConvertSimpleOperator<TensorFlowShapeOperator, 1>}, + {"Shape", ConvertShapeOperator}, {"Sigmoid", ConvertSimpleOperator<LogisticOperator, 1>}, {"Sin", ConvertSimpleOperator<SinOperator, 1>}, {"Slice", ConvertSimpleOperator<SliceOperator, 3>}, |