aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Russell Power <power@google.com>2017-11-13 18:19:59 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-13 18:24:33 -0800
commit80a3d011807e7b3a9de4d58e082acf2e091d7927 (patch)
tree74957dc1c11964adb201a17c3bc5313c13a057d7
parent8997ae6271cd2c496988ceeedab1d31755d65da4 (diff)
Add a model comparison function to TPU test utilities.
PiperOrigin-RevId: 175620458
-rw-r--r--tensorflow/contrib/tpu/BUILD1
-rw-r--r--tensorflow/contrib/tpu/python/tpu/test_util.py137
2 files changed, 130 insertions, 8 deletions
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index e14c36ae43..64e9d0e765 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -16,6 +16,7 @@ package(
"//cloud/vmm/testing/tests/tpu:__subpackages__",
"//learning/brain:__subpackages__",
"//tensorflow:__subpackages__",
+ "//third_party/cloud_tpu:__subpackages__",
],
)
diff --git a/tensorflow/contrib/tpu/python/tpu/test_util.py b/tensorflow/contrib/tpu/python/tpu/test_util.py
index f30c27f129..b83c72d0ff 100644
--- a/tensorflow/contrib/tpu/python/tpu/test_util.py
+++ b/tensorflow/contrib/tpu/python/tpu/test_util.py
@@ -18,14 +18,26 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.tpu.python.tpu import tpu
+import os.path
+import pickle
+import tempfile
+
+import numpy as np
-from tensorflow.python.client import session
+from tensorflow.contrib.tpu.python.tpu import tpu
+from tensorflow.contrib.tpu.python.tpu import tpu_config
+from tensorflow.contrib.tpu.python.tpu import tpu_estimator
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session as tf_session
+from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import saver as tf_saver
def has_tpu():
@@ -38,8 +50,9 @@ def has_tpu():
Returns:
boolean, True if a TPU device is available, otherwise False.
"""
+
def _check():
- with session.Session() as sess:
+ with tf_session.Session() as sess:
sess.run(tpu.initialize_system())
sess.run(tpu.shutdown_system())
@@ -61,6 +74,111 @@ def _available_devices():
return tuple(devices)
+def copy_dir(src, tgt):
+ """Copy src to tgt."""
+ gfile.MakeDirs(tgt)
+ seen_dirs = set()
+ for dirname, _, files in gfile.Walk(src):
+ for f in files:
+ src_f = os.path.join(dirname, f)
+ tgt_f = src_f.replace(src, tgt)
+ tgt_d = os.path.dirname(tgt_f)
+ if tgt_d not in seen_dirs:
+ gfile.MkDir(tgt_d)
+ seen_dirs.add(tgt_d)
+ gfile.Copy(src_f, tgt_f, overwrite=True)
+
+
+def compare_model(model_fn, input_fn, params, master="local", temp_dir=None,
+ tolerance=1e-4):
+ """Compare the results of running `model_fn` on the TPU and CPU."""
+ if not temp_dir:
+ temp_dir = tempfile.mkdtemp()
+
+ cpu_model_dir = "%s/cpu-model" % temp_dir
+ tpu_model_dir = "%s/tpu-model" % temp_dir
+ initial_model_dir = "%s/initial-model" % temp_dir
+
+ logging.info("Checkpoints and weights will be written to %s", temp_dir)
+
+ num_steps = 1
+ num_shards = 8
+
+ def _make_run_config(model_dir):
+ return tpu_config.RunConfig(
+ master=master,
+ model_dir=model_dir,
+ save_checkpoints_secs=10000,
+ session_config=config_pb2.ConfigProto(
+ allow_soft_placement=True, log_device_placement=False),
+ tpu_config=tpu_config.TPUConfig(
+ iterations_per_loop=num_steps,
+ num_shards=num_shards,
+ ),
+ )
+
+ def _make_estimator(use_tpu, model_dir):
+ return tpu_estimator.TPUEstimator(
+ model_fn=model_fn,
+ use_tpu=use_tpu,
+ config=_make_run_config(model_dir),
+ train_batch_size=num_shards,
+ params=dict(params, use_tpu=use_tpu),
+ )
+
+ def _extract_weights(checkpoint):
+ """Extract model weights from the given checkpoint file."""
+ weights = {}
+ graph = ops.Graph()
+ with graph.as_default():
+ model_fn(
+ *input_fn(params),
+ params=dict(params, use_tpu=False),
+ mode=model_fn_lib.ModeKeys.TRAIN)
+ saver = tf_saver.Saver()
+ with tf_session.Session(graph=graph) as sess:
+ saver.restore(sess, checkpoint)
+ all_vars = []
+ all_vars.extend(graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
+ all_vars.extend(graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
+ all_vars.extend(graph.get_collection(ops.GraphKeys.MODEL_VARIABLES))
+
+ for var in all_vars:
+ weights[var.name] = sess.run(var)
+ return weights
+
+ def _run_step(use_tpu, model_dir):
+ est = _make_estimator(use_tpu=use_tpu, model_dir=model_dir)
+ est.train(input_fn=input_fn, steps=num_steps)
+ weights = _extract_weights(est.latest_checkpoint())
+ with gfile.Open(temp_dir + "tpu-%d.weights" % use_tpu, "wb") as f:
+ f.write(pickle.dumps(weights))
+ return weights
+
+ # initialize models to the same weights by running a single step on the CPU
+ _run_step(use_tpu=False, model_dir=initial_model_dir)
+
+ copy_dir(initial_model_dir, cpu_model_dir)
+ cpu_weights = _run_step(use_tpu=False, model_dir=cpu_model_dir)
+
+ copy_dir(initial_model_dir, tpu_model_dir)
+ tpu_weights = _run_step(use_tpu=True, model_dir=tpu_model_dir)
+
+ bad_weights = False
+ for k in cpu_weights:
+ if k not in tpu_weights:
+ raise KeyError("Missing weight %s from TPU checkpoint.", k)
+
+ if not np.allclose(
+ cpu_weights[k], tpu_weights[k], rtol=tolerance, atol=tolerance):
+ bad_weights = True
+ logging.error("Weights for layer %s have diverged.", k)
+
+ if bad_weights:
+ raise ValueError("Some weights have diverged. Output pickle files have "
+ "been written to %s for inspection." % temp_dir)
+
+
class TPUTestCase(test_util.TensorFlowTestCase):
"""Adds helpers for testing on TPU devices to `TensorFlowTestCase`.
@@ -68,7 +186,7 @@ class TPUTestCase(test_util.TensorFlowTestCase):
```
def model_fn(features):
- return tf.reduce_sum(features * 2)
+ return tf.reduce_sum(features * 2)
class ModelTests(test_util.TPUTestCase):
def test_sum(self):
@@ -97,10 +215,10 @@ class TPUTestCase(test_util.TensorFlowTestCase):
Returns:
Output from the model function.
"""
+
def _make_placeholders():
- return dict(
- [(gen_array_ops.placeholder_with_default(v, v.shape), v)
- for v in model_inputs])
+ return dict([(gen_array_ops.placeholder_with_default(v, v.shape), v)
+ for v in model_inputs])
if device == "tpu":
with self.test_session(graph=ops.Graph()) as sess:
@@ -133,7 +251,10 @@ class TPUTestCase(test_util.TensorFlowTestCase):
else:
self.assertAllCloseAccordingToType(actual_outputs, expected_outputs)
- def assert_device_output(self, model_fn, model_inputs, expected_outputs,
+ def assert_device_output(self,
+ model_fn,
+ model_inputs,
+ expected_outputs,
devices=("cpu", "gpu", "tpu")):
"""Run `model_fn` on the given devices.