Diffstat (limited to 'tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py')
-rw-r--r--  tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py  153
1 file changed, 153 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py b/tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py
new file mode 100644
index 0000000000..001f9170bc
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/mpi_allreduce_test.py
@@ -0,0 +1,153 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+import tensorflow as tf
+import tensorflow.contrib.mpi_collectives as mpi
+from tensorflow.python.platform import test
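+# NOTE: this test is not meant to be run standalone; it must be launched by
+# an MPI process manager so that PMI_RANK/PMI_SIZE are set (for a PMI-based
+# launcher, something like `mpiexec -n 2 python mpi_allreduce_test.py` --
+# the exact invocation depends on the MPI stack).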
+
+
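+# Test knobs: `average_allreduce` switches both the MPI allreduce and the
+# local reference computation between sum and mean mode; `max_wrong_count`
+# caps how many mismatched entries dumpFailure prints (<= 0 means no cap).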
+average_allreduce = False
+max_wrong_count = -1
+
+
+class AllreduceTest(test.TestCase):
+  def dumpFailure(self, my_rank, out_loc_red, my_correct, out_all_red,
+                  our_correct):
+    # Find reduced/allreduced indices that are wrong and print all the
+    # values from output, slices, reduced, allreduced, so we can debug
+    # which is incorrect:
+    wrong_count = 0
+    red_dims = out_loc_red.shape
+    assert len(red_dims) == 2
+    for i in range(red_dims[0]):
+      for j in range(red_dims[1]):
+        suffix = ""
+        if out_loc_red[i][j] != my_correct[i][j] or \
+           out_all_red[i][j] != our_correct[i][j]:
+          suffix = "WRONG"
+          wrong_count += 1
+        print("{}\t{}\t{}\t{}\t{}\t{}"
+              .format(my_rank, i, j, out_loc_red[i][j],
+                      out_all_red[i][j], suffix), flush=True)
+        if max_wrong_count > 0 and wrong_count >= max_wrong_count:
+          return
+
+  def test_mpi_allreduce(self):
+    # Get MPI rank
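+    # (PMI_RANK/PMI_SIZE are exported by PMI-based launchers such as
+    # MPICH's Hydra mpiexec; other MPI stacks use different variable names,
+    # e.g. OMPI_COMM_WORLD_RANK under Open MPI.)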
+    my_rank = int(os.environ['PMI_RANK'])
+    num_ranks = int(os.environ['PMI_SIZE'])
+
+    stages = 13
+    batch_size = 1331
+    hidden_size = batch_size
+    out_size = batch_size
+
+    # Input placeholder (batch_size x hidden_size) - fed with ones (plus a
+    # per-rank offset) in the session below
+    inputs = tf.placeholder(tf.float32, shape=(batch_size, hidden_size),
+                            name="Input")
+
+    # Large weight matrices (hidden_size x out_size) - each constant-
+    # initialized to 2**(i + 1), so all ranks hold identical weights
+    weights = []
+    for i in range(stages):
+      initer = tf.constant_initializer(pow(2.0, i + 1.0))
+      weights.append(tf.get_variable("weights_{}".format(i),
+                                     shape=(hidden_size, out_size),
+                                     dtype=tf.float32,
+                                     initializer=initer))
+
+    # Calculate output through dependent allreduces
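+    # Stage i adds weights[i] to its input and allreduces the result, so
+    # stage i + 1 consumes the allreduced output of stage i; this chains
+    # `stages` dependent collectives into one graph.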
+    stage_input = inputs
+    for i in range(stages):
+      inter_output = tf.add(stage_input, weights[i],
+                            name="add_red_{}".format(i))
+      stage_input = mpi.allreduce(inter_output,
+                                  average=average_allreduce)
+
+    all_reduced = stage_input
+
+    # Local reduced output for verification
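+    # The reference chain mimics the sum-allreduce without MPI: each stage
+    # starts a reducer from a ones tensor and adds this rank's stage output
+    # num_ranks times, which is why the expected-value recurrence below
+    # carries a "+ 1" per stage.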
+    local_input = inputs
+    for i in range(stages):
+      inter_output = tf.add(local_input, weights[i],
+                            name="addin_loc_{}".format(i))
+      my_reducer = tf.Variable(initial_value=np.ones((hidden_size, out_size)),
+                               dtype=tf.float32, name="loc_redr_{}".format(i))
+      for r in range(num_ranks):
+        my_reducer = tf.add(my_reducer, inter_output,
+                            name="add_loc_{}_{}".format(i, r))
+      if average_allreduce:
+        local_input = tf.div(my_reducer, num_ranks,
+                             name="div_loc_{}".format(i))
+      else:
+        local_input = my_reducer
+
+    local_reduced = local_input
+
+    # NOTE: This assumes that device IDs are numbered the same as ranks
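+    # visible_device_list restricts this process to a single GPU, so rank r
+    # only ever sees GPU r.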
+    gpu_options = tf.GPUOptions(visible_device_list=str(my_rank))
+    config = tf.ConfigProto(gpu_options=gpu_options)
+
+    # MPI Session to test allreduce
+    with mpi.Session(config=config) as sess:
+      sess.run(tf.global_variables_initializer())
+
+      input_feed = np.ones((batch_size, hidden_size), dtype=np.float32)
+      our_output = input_feed[0][0]
+      spread_var = 100
+      input_feed = input_feed + my_rank * spread_var
+      my_output = input_feed[0][0]
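+      # Closed-form expected values: each stage adds 2**(i + 1) and then
+      # sums over ranks. Per-rank inputs differ by my_rank * spread_var, so
+      # stage 0 of the real allreduce also picks up
+      # spread_var * sum(0..num_ranks - 1); after that every rank holds the
+      # same tensor and each allreduce reduces to a multiply by num_ranks.
+      # The local reference chain additionally adds 1 per stage (the ones
+      # tensor its reducer starts from).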
+      for i in range(stages):
+        curr_feed = my_output + pow(2.0, i + 1.0)
+        my_output = curr_feed * num_ranks + 1
+        curr_our_feed = our_output + pow(2.0, i + 1.0)
+        if i == 0:
+          sum_ranks = num_ranks * (num_ranks - 1) // 2
+          our_output = curr_our_feed * num_ranks + \
+                       spread_var * sum_ranks
+        else:
+          our_output = curr_our_feed * num_ranks
+
+      print("rank {}: My output is {}".format(my_rank, my_output))
+      my_correct = np.zeros((batch_size, hidden_size), dtype=np.float32)
+      my_correct = my_correct + my_output
+      print("rank {}: Our output is {}".format(my_rank, our_output))
+      our_correct = np.zeros((batch_size, hidden_size), dtype=np.float32)
+      our_correct = our_correct + our_output
+
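+      # Run the same graph repeatedly so intermittent collective failures
+      # surface, not just a single-shot mismatch.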
+      for i in range(1000):
+        if i % 100 == 0:
+          print("{}: iter {}".format(my_rank, i), flush=True)
+        feed_dict = {inputs: input_feed}
+        out_all_red, out_loc_red \
+            = sess.run([all_reduced, local_reduced],
+                       feed_dict=feed_dict)
+
+        if not np.allclose(out_loc_red, my_correct) or \
+           not np.allclose(out_all_red, our_correct):
+          print("Test incorrect on iter {}".format(i), flush=True)
+          self.dumpFailure(my_rank, out_loc_red, my_correct, out_all_red,
+                           our_correct)
+          self.assertTrue(np.allclose(out_loc_red, my_correct) and
+                          np.allclose(out_all_red, our_correct))
+
+
+if __name__ == '__main__':
+  test.main()