aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-25 14:42:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-25 14:49:24 -0800
commitb998b7b456066530dd27ef532dae195d27505266 (patch)
tree86a8bc1c7c689c2ce63e67271ceb48aa9412d8d3 /tensorflow/contrib
parent73b4b1502924acd461013d4ecf9825aedd3a3968 (diff)
Drop the manually_create field from RnnState.
Initially, I thought that the shape of RNN state arrays could always be determined by shape propagation. Then I came across some graphs where this wasn't so easy to infer, so I introduced manually_create thinking of it as a hack. Today I took another look at dropping that hack, and had a "D'oh" moment when I realized that the cyclic nature of RNN graphs makes it impossible to infer the shapes of all arrays by usual propagation. For example, in a LSTM cell, the input array is concatenated with a state array, so if we don't already know the shape of that state array, shape propagation stops there. Thus, this change removes manually_create by making toco always behave as if manually_create=true, i.e. early-creating all RNN state arrays with the shape explicitly specified by the user. The next TODO item here (see model_flags.proto) is to introduce a generic 'shape' field, so far the current 'size' field only allows specifying 1-D shapes. PiperOrigin-RevId: 183294102
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/lite/toco/allocate_transient_arrays.cc4
-rw-r--r--tensorflow/contrib/lite/toco/model_cmdline_flags.cc3
-rw-r--r--tensorflow/contrib/lite/toco/model_flags.proto15
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc13
4 files changed, 10 insertions, 25 deletions
diff --git a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
index 5961d30bf5..49cc1fc2aa 100644
--- a/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
+++ b/tensorflow/contrib/lite/toco/allocate_transient_arrays.cc
@@ -158,9 +158,7 @@ std::size_t TransientArraySize(const Model& model, const string& array_name,
LOG(FATAL)
<< "A RNN state array, " << array_name << ", still does not "
<< "have a known data type after all graph transformations have "
- << "run. That's mostly a toco bug --- sorry. For now, you can "
- << "work around this issue by adding manually_create:true in the "
- << "--rnn_state description of this RNN state.";
+ << "run.";
}
}
LOG(FATAL) << "An array, " << array_name << ", still does not "
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
index 790b3443ce..36520d9c55 100644
--- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
@@ -327,9 +327,6 @@ void ReadModelFlagsFromCommandLineFlags(
CHECK(absl::SimpleAtoi(value, &size));
CHECK_GT(size, 0);
rnn_state_proto->set_size(size);
- } else if (key == "manually_create") {
- CHECK_EQ(absl::AsciiStrToLower(value), "true");
- rnn_state_proto->set_manually_create(true);
} else {
LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
}
diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto
index 13fea29a07..9070ddc883 100644
--- a/tensorflow/contrib/lite/toco/model_flags.proto
+++ b/tensorflow/contrib/lite/toco/model_flags.proto
@@ -81,19 +81,10 @@ message RnnState {
optional string state_array = 1;
optional string back_edge_source_array = 2;
optional bool discardable = 5;
- // TODO(benoitjacob): drop the 'size' field. Should be redundant with
- // --input_shapes and shapes propagation.
+ // size allows to specify a 1-D shape for the RNN state array.
+ // Will be expanded with 1's to fit the model.
+ // TODO(benoitjacob): should allow a generic, explicit shape.
optional int32 size = 3;
- // TODO(benoitjacob): manually_create is a temporary hack:
- // due to discrepancies between the current toco dims tracking and
- // TensorFlow shapes, for some models we need to manually create RNN state
- // arrays with a specified shape.
- // Maybe we should actually implement back-edges as operators of their own,
- // which would remove the need for much special-casing, including here,
- // we could probably consistently let PropagateFixedSizes handle state
- // arrays.
- // TODO(benoitjacob): should really drop manually_create now.
- optional bool manually_create = 4;
}
// ModelFlags encodes properties of a model that, depending on the file
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 99a54a300b..df785a5102 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -958,7 +958,9 @@ void CheckModelCounts(const Model& model) {
void MakeArrayDims(int num_dims, int batch, int height, int width, int depth,
std::vector<int>* out_dims) {
CHECK(out_dims->empty());
- if (num_dims == 1) {
+ if (num_dims == 0) {
+ return;
+ } else if (num_dims == 1) {
CHECK_EQ(batch, 1);
*out_dims = {depth};
} else if (num_dims == 2) {
@@ -990,13 +992,13 @@ void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) {
if (array.has_shape()) {
num_dims = array.shape().dimensions_count();
}
- std::vector<int> dims;
- MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
CHECK(array.data_type == ArrayDataType::kFloat ||
array.data_type == ArrayDataType::kNone);
array.data_type = ArrayDataType::kFloat;
- if (!array.has_shape()) {
+ if (!array.has_shape() && num_dims >= 0) {
Shape* shape = array.mutable_shape();
+ std::vector<int> dims;
+ MakeArrayDims(num_dims, batch, 1, 1, size, &dims);
*shape->mutable_dims() = dims;
}
}
@@ -1185,9 +1187,6 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
}
// Creation of the RNN state arrays
for (const auto& rnn_state : model->flags.rnn_states()) {
- if (!rnn_state.manually_create()) {
- continue;
- }
CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
model);
}