aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-25 08:20:51 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-25 08:28:58 -0800
commit10d7ddfa9bb95d65f7245dae4230a00b0badde06 (patch)
tree95619c94cb2df696c90c8728694778ab71c51e83 /tensorflow
parent028ef1e67201700e8d9d77af64655f1dd20ae665 (diff)
Automated g4 rollback of changelist 183239252
PiperOrigin-RevId: 183241034
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/lite/toco/allocate_transient_arrays.cc4
-rw-r--r--tensorflow/contrib/lite/toco/model_cmdline_flags.cc3
-rw-r--r--tensorflow/contrib/lite/toco/model_flags.proto15
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc13
4 files changed, 25 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
index 49cc1fc2aa..5961d30bf5 100644
--- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
+++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
@@ -158,7 +158,9 @@ std::size_t TransientArraySize(const Model& model, const string& array_name,
LOG(FATAL)
<< "A RNN state array, " << array_name << ", still does not "
<< "have a known data type after all graph transformations have "
- << "run.";
+ << "run. That's mostly a toco bug --- sorry. For now, you can "
+ << "work around this issue by adding manually_create:true in the "
+ << "--rnn_state description of this RNN state.";
}
}
LOG(FATAL) << "An array, " << array_name << ", still does not "
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
index 36520d9c55..790b3443ce 100644
--- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
@@ -327,6 +327,9 @@ void ReadModelFlagsFromCommandLineFlags(
CHECK(absl::SimpleAtoi(value, &size));
CHECK_GT(size, 0);
rnn_state_proto->set_size(size);
+ } else if (key == "manually_create") {
+ CHECK_EQ(absl::AsciiStrToLower(value), "true");
+ rnn_state_proto->set_manually_create(true);
} else {
LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
}
diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto
index 9070ddc883..13fea29a07 100644
--- a/tensorflow/contrib/lite/toco/model_flags.proto
+++ b/tensorflow/contrib/lite/toco/model_flags.proto
@@ -81,10 +81,19 @@ message RnnState {
optional string state_array = 1;
optional string back_edge_source_array = 2;
optional bool discardable = 5;
- // size allows to specify a 1-D shape for the RNN state array.
- // Will be expanded with 1's to fit the model.
- // TODO(benoitjacob): should allow a generic, explicit shape.
+ // TODO(benoitjacob): drop the 'size' field. Should be redundant with
+ // --input_shapes and shapes propagation.
optional int32 size = 3;
+ // TODO(benoitjacob): manually_create is a temporary hack:
+ // due to discrepancies between the current toco dims tracking and
+ // TensorFlow shapes, for some models we need to manually create RNN state
+ // arrays with a specified shape.
+ // Maybe we should actually implement back-edges as operators of their own,
+ // which would remove the need for much special-casing, including here,
+ // we could probably consistently let PropagateFixedSizes handle state
+ // arrays.
+ // TODO(benoitjacob): should really drop manually_create now.
+ optional bool manually_create = 4;
}
// ModelFlags encodes properties of a model that, depending on the file
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index df785a5102..99a54a300b 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -958,9 +958,7 @@ void CheckModelCounts(const Model& model) {
void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
std::vector<int>* out_dims) {
CHECK(out_dims->empty());
- if (num_dims == 0) {
- return;
- } else if (num_dims == 1) {
+ if (num_dims == 1) {
CHECK_EQ(batch, 1);
*out_dims = {depth};
} else if (num_dims == 2) {
@@ -992,13 +990,13 @@ void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) {
if (array.has_shape()) {
num_dims = array.shape().dimensions_count();
}
+ std::vector<int> dims;
+ MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
CHECK(array.data_type == ArrayDataType::kFloat ||
array.data_type == ArrayDataType::kNone);
array.data_type = ArrayDataType::kFloat;
- if (!array.has_shape() && num_dims >= 0) {
+ if (!array.has_shape()) {
Shape* shape = array.mutable_shape();
- std::vector<int> dims;
- MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
*shape->mutable_dims() = dims;
}
}
@@ -1187,6 +1185,9 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
}
// Creation of the RNN state arrays
for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (!rnn_state.manually_create()) {
+ continue;
+ }
CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
model);
}