aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing/generate_examples.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/testing/generate_examples.py')
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py18
1 files changed, 12 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 86540d58a6..5bca82ded0 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -36,6 +36,11 @@ import traceback
import zipfile
import numpy as np
from six import StringIO
+
+# TODO(aselle): Disable GPU for now
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+# pylint: disable=g-import-not-at-top
import tensorflow as tf
from google.protobuf import text_format
# TODO(aselle): switch to TensorFlow's resource_loader
@@ -379,12 +384,13 @@ def make_zip_of_tests(zip_path,
report["toco_log"] = ""
tf.reset_default_graph()
- try:
- inputs, outputs = make_graph(param_dict_real)
- except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
- ValueError):
- report["tf_log"] += traceback.format_exc()
- return None, report
+ with tf.device("/cpu:0"):
+ try:
+ inputs, outputs = make_graph(param_dict_real)
+ except (tf.errors.UnimplementedError, tf.errors.InvalidArgumentError,
+ ValueError):
+ report["tf_log"] += traceback.format_exc()
+ return None, report
sess = tf.Session()
try: