aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-02 19:13:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 19:16:58 -0700
commit8dc7bc7764150253c03a666eee84fc48f867d6a2 (patch)
tree47b621fdcdcb30428f342179cebbf93561cca487 /tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc
parentf8ba42b0ab0bb19af0e4a930b95e7e7b3d2f557e (diff)
In all constant-propagation transformations, check that the array we'd be turning into a constant is a
discardable array. If it's not discardable, it means that the user wants this array to keep existing in a way that is observable to them, i.e. not as weights. Typical example: a Fill op outputs an array that is passed as a RNN state array (non-discardable). It seems that so far we have been relying on accidental ordering of graph transformations for such state arrays not to be accidentally turned into constants. Instead, the desired graph transformation here is RemoveUnusedOp noticing that such a Fill can be discarded since its output is a RNN state array. So I don't have a test for this, but this seems to be tightening existing behavior, and should be good to have as long as it does not regress anything. PiperOrigin-RevId: 215500760
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc
index b35c3e19c4..58d6797e1c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc
@@ -96,6 +96,14 @@ bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) {
const SliceOperator* op = static_cast<const SliceOperator*>(base_op);
CHECK_EQ(op->outputs.size(), 1);
+
+ // If the output of this op is a non-discardable array such as an input_array
+ // or a state array of the model, then this is a job for RemoveUnusedOp, not
+ // for constants-propagation.
+ if (!IsDiscardableArray(*model, op->outputs[0])) {
+ return false;
+ }
+
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.