aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-03 06:14:11 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 06:18:27 -0700
commitdd52e1d30702df5dfc805a1f433061dfbb75c814 (patch)
tree87b3e75c804282d061a304c69b780f0372b5f73b /tensorflow/contrib/lite
parentc248f458c76df89fa3d608dcbe7c4c5e10962c24 (diff)
Fix test that was relying on old lax toco behavior
PiperOrigin-RevId: 215553161
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py14
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 18036fac6f..3f2255c454 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -762,8 +762,11 @@ def make_constant_tests(zip_path):
dtype=parameters["dtype"],
name="input1",
shape=parameters["input_shape"])
- out = tf.constant(
+ constant = tf.constant(
create_tensor_data(parameters["dtype"], parameters["input_shape"]))
+ # This maximum node is here to avoid the situation where a graph output is
+ # a constant, which is an error in toco.
+ out = tf.maximum(dummy_input, constant)
return [dummy_input], [out]
def build_inputs(parameters, sess, inputs, outputs):
@@ -2848,7 +2851,14 @@ def make_zeros_like_tests(zip_path):
dtype=parameters["input_dtype"],
name="input",
shape=parameters["input_shape"])
- out = tf.zeros_like(input_tensor)
+ zeros = tf.zeros_like(input_tensor)
+ # This maximum node is so that toco can perform the constants-propagation
+ # through the above zeros_like, which it can't do if the output of the
+ # zeros_like as an output of the whole graphs (graph outputs can't be
+ # constants). If toco does not perform such constants-propagation then
+ # the resulting tflite graph retains the zeros_like as a Fill op, which
+ # is unsupported by TFLite, even as a custom op.
+ out = tf.maximum(zeros, input_tensor)
return [input_tensor], [out]
def build_inputs(parameters, sess, inputs, outputs):