aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-26 08:13:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-26 08:16:27 -0700
commit9640698f25a9c69ebd496d976d50a3b6e0c6431c (patch)
tree4b808a49499400c2d11f92960055943482b1c08c /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parentc15341ed83697929c1075d411e1607adadf0dbe4 (diff)
Automated g4 rollback of changelist 202120850
PiperOrigin-RevId: 202130585
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc37
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 01a51802d4..c61da203c6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -133,7 +133,36 @@ int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
}
}
+bool EnsureBiasVectorShape(Model* model, Operator* op) {
+ const string& weights_name = op->inputs[1];
+ const auto& weights_array = model->GetArray(weights_name);
+ // Yield until weights shape has been resolved.
+ if (!weights_array.has_shape()) {
+ return false;
+ }
+
+ if (op->inputs.size() < 3) {
+ return false;
+ }
+ auto& bias_array = model->GetArray(op->inputs[2]);
+ if (bias_array.has_shape()) {
+ return true;
+ }
+
+ const int output_depth = GetOutputDepthFromWeights(*model, *op);
+ bias_array.copy_shape(Shape({output_depth}));
+
+ auto& float_buffer = bias_array.GetMutableBuffer<ArrayDataType::kFloat>();
+ float_buffer.data.resize(output_depth, 0);
+
+ return true;
+}
+
void ProcessConvOperator(Model* model, ConvOperator* op) {
+ if (!EnsureBiasVectorShape(model, op)) {
+ return;
+ }
+
const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
@@ -263,6 +292,10 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
}
void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
+ if (!EnsureBiasVectorShape(model, op)) {
+ return;
+ }
+
const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {
@@ -377,6 +410,10 @@ void ProcessOpWithShapeInput(Model* model, Operator* op) {
}
void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
+ if (!EnsureBiasVectorShape(model, op)) {
+ return;
+ }
+
const auto& input_array = model->GetArray(op->inputs[0]);
// Yield until input dims have been resolved.
if (!input_array.has_shape()) {