diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index a03b589bae..5aa0fddf57 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1578,6 +1578,61 @@ void ProcessAnyOperator(Model* model, AnyOperator* op) { } } +void ProcessOneHotOperator(Model* model, OneHotOperator* op) { + CHECK_EQ(op->inputs.size(), 4); + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.has_shape()) { + // Shape already propagated + return; + } + + // Yield until indices dims have been resolved. + const auto& indices_array = + model->GetArray(op->inputs[OneHotOperator::INDICES_INPUT]); + if (!indices_array.has_shape()) { + return; + } + + // Yield until depth is constant and dims have been resolved. + if (!IsConstantParameterArray(*model, + op->inputs[OneHotOperator::DEPTH_INPUT])) { + return; + } + const auto& depth_array = + model->GetArray(op->inputs[OneHotOperator::DEPTH_INPUT]); + if (!depth_array.has_shape()) { + return; + } + + CHECK(depth_array.data_type == ArrayDataType::kInt32) + << "Depth array must be int32."; + CHECK_EQ(RequiredBufferSizeForShape(depth_array.shape()), 1) + << "Depth array must be scalar."; + + const int depth = depth_array.GetBuffer<ArrayDataType::kInt32>().data[0]; + CHECK_GE(depth, 0) << "Depth must be non-negative."; + + const int indices_dims = indices_array.shape().dimensions_count(); + const int output_dims = indices_dims + 1; + const int axis = op->axis == -1 ? indices_dims : op->axis; + CHECK_GE(axis, 0) << "Resolved axis must be non-negative."; + + auto* mutable_dims = output_array.mutable_shape()->mutable_dims(); + mutable_dims->resize(output_dims); + for (int i = 0; i < output_dims; ++i) { + int dim = 0; + if (i < axis) { + dim = indices_array.shape().dims(i); + } else if (i == axis) { + dim = depth; + } else { + dim = indices_array.shape().dims(i - 1); + } + (*mutable_dims)[i] = dim; + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1825,6 +1880,9 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kAny: ProcessAnyOperator(model, static_cast<AnyOperator*>(op)); break; + case OperatorType::kOneHot: + ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); |