diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-12-08 13:25:49 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-08 13:30:15 -0800 |
commit | dc04e89bc6f0421bf77ac69f21c1f2f57618f53c (patch) | |
tree | 49b05183e6a6dd2e8a5afff97a0fe35615de1127 /tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | |
parent | 1afc6149ed0649971d83fe8e9748056285dcf332 (diff) |
Adding support for new TensorFlow operators. Also adding a transformation to convert an ExpandDims into a Reshape op.
PiperOrigin-RevId: 178418377
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 550e0408aa..1d92bcbccd 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -59,13 +59,15 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { op->type == OperatorType::kTensorFlowGreaterEqual) { // These operators unconditionally produce bool outputs SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool); - } else if (op->type == OperatorType::kTensorFlowShape) { + } else if (op->type == OperatorType::kRank || + op->type == OperatorType::kTensorFlowShape) { // These operators are assumed to produce int32 outputs. SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32); } else if (op->type == OperatorType::kTensorFlowSplit || - op->type == OperatorType::kTensorFlowConcat) { + op->type == OperatorType::kTensorFlowConcat || + op->type == OperatorType::kFill) { // These operators produce an output with the same type as their 2nd input - CHECK_GT(op->inputs.size(), 1); + CHECK_GE(op->inputs.size(), 2); const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type; SetDataTypeForAllOutputs(model, op, data_type); } else if (op->type == OperatorType::kCast) { @@ -83,6 +85,9 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { auto data_type = unsupported_op->output_data_types[i]; model->arrays[output]->data_type = data_type; } + } else if (op->type == OperatorType::kExpandDims) { + // Yield on ExpandDim until it is converted to Reshape + return false; } else { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); |