diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-25 08:20:51 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-25 08:28:58 -0800 |
commit | 10d7ddfa9bb95d65f7245dae4230a00b0badde06 (patch) | |
tree | 95619c94cb2df696c90c8728694778ab71c51e83 /tensorflow/contrib/lite/toco | |
parent | 028ef1e67201700e8d9d77af64655f1dd20ae665 (diff) |
Automated g4 rollback of changelist 183239252
PiperOrigin-RevId: 183241034
Diffstat (limited to 'tensorflow/contrib/lite/toco')
4 files changed, 25 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc index 49cc1fc2aa..5961d30bf5 100644 --- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc +++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc @@ -158,7 +158,9 @@ std::size_t TransientArraySize(const Model& model, const string& array_name, LOG(FATAL) << "A RNN state array, " << array_name << ", still does not " << "have a known data type after all graph transformations have " - << "run."; + << "run. That's mostly a toco bug --- sorry. For now, you can " + << "work around this issue by adding manually_create:true in the " + << "--rnn_state description of this RNN state."; } } LOG(FATAL) << "An array, " << array_name << ", still does not " diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc index 36520d9c55..790b3443ce 100644 --- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc @@ -327,6 +327,9 @@ void ReadModelFlagsFromCommandLineFlags( CHECK(absl::SimpleAtoi(value, &size)); CHECK_GT(size, 0); rnn_state_proto->set_size(size); + } else if (key == "manually_create") { + CHECK_EQ(absl::AsciiStrToLower(value), "true"); + rnn_state_proto->set_manually_create(true); } else { LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states"; } diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto index 9070ddc883..13fea29a07 100644 --- a/tensorflow/contrib/lite/toco/model_flags.proto +++ b/tensorflow/contrib/lite/toco/model_flags.proto @@ -81,10 +81,19 @@ message RnnState { optional string state_array = 1; optional string back_edge_source_array = 2; optional bool discardable = 5; - // size allows to specify a 1-D shape for the RNN state array. - // Will be expanded with 1's to fit the model. - // TODO(benoitjacob): should allow a generic, explicit shape. + // TODO(benoitjacob): drop the 'size' field. Should be redundant with + // --input_shapes and shapes propagation. optional int32 size = 3; + // TODO(benoitjacob): manually_create is a temporary hack: + // due to discrepancies between the current toco dims tracking and + // TensorFlow shapes, for some models we need to manually create RNN state + // arrays with a specified shape. + // Maybe we should actually implement back-edges as operators of their own, + // which would remove the need for much special-casing, including here, + // we could probably consistently let PropagateFixedSizes handle state + // arrays. + // TODO(benoitjacob): should really drop manually_create now. + optional bool manually_create = 4; } // ModelFlags encodes properties of a model that, depending on the file diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index df785a5102..99a54a300b 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -958,9 +958,7 @@ void CheckModelCounts(const Model& model) { void MakeArrayDims(int num_dims, int batch, int height, int width, int depth, std::vector<int>* out_dims) { CHECK(out_dims->empty()); - if (num_dims == 0) { - return; - } else if (num_dims == 1) { + if (num_dims == 1) { CHECK_EQ(batch, 1); *out_dims = {depth}; } else if (num_dims == 2) { @@ -992,13 +990,13 @@ void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) { if (array.has_shape()) { num_dims = array.shape().dimensions_count(); } + std::vector<int> dims; + MakeArrayDims(num_dims, batch, 1, 1, size, &dims); CHECK(array.data_type == ArrayDataType::kFloat || array.data_type == ArrayDataType::kNone); array.data_type = ArrayDataType::kFloat; - if (!array.has_shape() && num_dims >= 0) { + if (!array.has_shape()) { Shape* shape = array.mutable_shape(); - std::vector<int> dims; - MakeArrayDims(num_dims, batch, 1, 1, size, &dims); *shape->mutable_dims() = dims; } } @@ -1187,6 +1185,9 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { } // Creation of the RNN state arrays for (const auto& rnn_state : model->flags.rnn_states()) { + if (!rnn_state.manually_create()) { + continue; + } CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(), model); } |