diff options
Diffstat (limited to 'tensorflow/contrib/lite/testing/generate_examples.py')
-rw-r--r-- | tensorflow/contrib/lite/testing/generate_examples.py | 18 |
1 files changed, 12 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 86540d58a6..5bca82ded0 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -36,6 +36,11 @@ import traceback import zipfile import numpy as np from six import StringIO + +# TODO(aselle): Disable GPU for now +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +# pylint: disable=g-import-not-at-top import tensorflow as tf from google.protobuf import text_format # TODO(aselle): switch to TensorFlow's resource_loader @@ -379,12 +384,13 @@ def make_zip_of_tests(zip_path, report["toco_log"] = "" tf.reset_default_graph() - try: - inputs, outputs = make_graph(param_dict_real) - except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError, - ValueError): - report["tf_log"] += traceback.format_exc() - return None, report + with tf.device("/cpu:0"): + try: + inputs, outputs = make_graph(param_dict_real) + except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError, + ValueError): + report["tf_log"] += traceback.format_exc() + return None, report sess = tf.Session() try: |