aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-23 21:50:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-23 21:57:27 -0700
commit4bc1d3e484c6eb3ea2ba4e6400722be32220c808 (patch)
treeb6f760b3003355257f57b2441ac273e7349e3b03 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parent0c657f3b9f6ef6ee63b3eb54fe928f482c58dc80 (diff)
Implementation of unpack op.
PiperOrigin-RevId: 210051131
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc29
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 91e290439a..fa2be961f5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1629,6 +1629,32 @@ void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
}
}
+void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
+ CHECK_EQ(op->inputs.size(), 1);
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+
+ const std::vector<int>& input_dims = input_array.shape().dims();
+ std::vector<int> output_dims;
+
+ output_dims.reserve(input_dims.size() - 1);
+ for (int i = 0; i < input_dims.size(); ++i) {
+ if (i != op->axis) {
+ output_dims.push_back(input_dims[i]);
+ }
+ }
+ for (const string& output_name : op->outputs) {
+ auto& output_array = model->GetArray(output_name);
+ if (output_array.has_shape()) {
+ return;
+ }
+ *output_array.mutable_shape()->mutable_dims() = output_dims;
+ }
+}
+
} // namespace
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
@@ -1880,6 +1906,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kOneHot:
ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op));
break;
+ case OperatorType::kUnpack:
+ ProcessUnpackOperator(model, static_cast<UnpackOperator*>(op));
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);