diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-26 08:13:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-26 08:16:27 -0700 |
commit | 9640698f25a9c69ebd496d976d50a3b6e0c6431c (patch) | |
tree | 4b808a49499400c2d11f92960055943482b1c08c /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | |
parent | c15341ed83697929c1075d411e1607adadf0dbe4 (diff) |
Automated g4 rollback of changelist 202120850
PiperOrigin-RevId: 202130585
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 01a51802d4..c61da203c6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -133,7 +133,36 @@ int GetOutputDepthFromWeights(const Model& model, const Operator& op) { } } +bool EnsureBiasVectorShape(Model* model, Operator* op) { + const string& weights_name = op->inputs[1]; + const auto& weights_array = model->GetArray(weights_name); + // Yield until weights shape has been resolved. + if (!weights_array.has_shape()) { + return false; + } + + if (op->inputs.size() < 3) { + return false; + } + auto& bias_array = model->GetArray(op->inputs[2]); + if (bias_array.has_shape()) { + return true; + } + + const int output_depth = GetOutputDepthFromWeights(*model, *op); + bias_array.copy_shape(Shape({output_depth})); + + auto& float_buffer = bias_array.GetMutableBuffer<ArrayDataType::kFloat>(); + float_buffer.data.resize(output_depth, 0); + + return true; +} + void ProcessConvOperator(Model* model, ConvOperator* op) { + if (!EnsureBiasVectorShape(model, op)) { + return; + } + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { @@ -263,6 +292,10 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) { } void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { + if (!EnsureBiasVectorShape(model, op)) { + return; + } + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { @@ -377,6 +410,10 @@ void ProcessOpWithShapeInput(Model* model, Operator* op) { } void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) { + if (!EnsureBiasVectorShape(model, op)) { + return; + } + const auto& input_array = model->GetArray(op->inputs[0]); // Yield until input dims have been resolved. if (!input_array.has_shape()) { |