diff options
author | 2018-10-02 19:13:14 -0700 | |
---|---|---|
committer | 2018-10-02 19:16:58 -0700 | |
commit | 8dc7bc7764150253c03a666eee84fc48f867d6a2 (patch) | |
tree | 47b621fdcdcb30428f342179cebbf93561cca487 /tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc | |
parent | f8ba42b0ab0bb19af0e4a930b95e7e7b3d2f557e (diff) |
In all constant-propagation transformations, check that the array we'd be turning into a constant is a
discardable array. If it's not discardable, it means that the user wants this array to keep existing
in a way that is observable to them, i.e. not as weights.
Typical example: a Fill op outputs an array that is passed as a RNN state array (non-discardable).
It seems that so far we have been relying on accidental ordering of graph transformations for such state
arrays not to be accidentally turned into constants. Instead, the desired graph transformation here is
RemoveUnusedOp noticing that such a Fill can be discarded since its output is a RNN state array.
So I don't have a test for this, but this seems to be tightening existing behavior, and should be good
to have as long as it does not regress anything.
PiperOrigin-RevId: 215500760
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc index b35c3e19c4..58d6797e1c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc @@ -96,6 +96,14 @@ bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) { const SliceOperator* op = static_cast<const SliceOperator*>(base_op); CHECK_EQ(op->outputs.size(), 1); + + // If the output of this op is a non-discardable array such as an input_array + // or a state array of the model, then this is a job for RemoveUnusedOp, not + // for constants-propagation. + if (!IsDiscardableArray(*model, op->outputs[0])) { + return false; + } + auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. |