aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/import_tensorflow.cc
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-06-20 11:48:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-20 11:51:26 -0700
commit4efefb90391b12c95339ed3b46a02b62ea5e195d (patch)
treebb3f9bb986b89287983ea8e7c35827993aad7206 /tensorflow/contrib/lite/toco/import_tensorflow.cc
parente51df5918020cdfada26022240091e5529f7da60 (diff)
Implement TFLite Shape operator
PiperOrigin-RevId: 201389618
Diffstat (limited to 'tensorflow/contrib/lite/toco/import_tensorflow.cc')
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc18
1 files changed, 17 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index caca199d2e..8da33e8a22 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1573,6 +1573,22 @@ tensorflow::Status ConvertOperatorSpecialCasedAsRNNBackEdge(
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertShapeOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "Shape");
+ TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 1));
+ const auto out_type =
+ HasAttr(node, "out_type") ? GetDataTypeAttr(node, "out_type") : DT_INT32;
+ CHECK(out_type == DT_INT64 || out_type == DT_INT32);
+ auto op = absl::make_unique<TensorFlowShapeOperator>();
+ op->output_data_type = ConvertDataType(out_type);
+ op->inputs.push_back(node.input(0));
+ op->outputs.push_back(node.name());
+ model->operators.push_back(std::move(op));
+ return tensorflow::Status::OK();
+}
+
void StripCaretFromArrayNames(Model* model) {
for (auto& op : model->operators) {
for (auto& input : op->inputs) {
@@ -1877,7 +1893,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"ResizeBilinear", ConvertResizeBilinearOperator},
{"Rsqrt", ConvertSimpleOperator<TensorFlowRsqrtOperator, 1>},
{"Select", ConvertSimpleOperator<SelectOperator, 3>},
- {"Shape", ConvertSimpleOperator<TensorFlowShapeOperator, 1>},
+ {"Shape", ConvertShapeOperator},
{"Sigmoid", ConvertSimpleOperator<LogisticOperator, 1>},
{"Sin", ConvertSimpleOperator<SinOperator, 1>},
{"Slice", ConvertSimpleOperator<SliceOperator, 3>},