diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-11 18:10:14 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-11 18:13:37 -0800 |
commit | 91526e1505e947fc64aece30ecfcd7ecec5de2c1 (patch) | |
tree | 0ed1618ec1e917f2cf8cf2b0a772c1b68ab146cb /tensorflow/contrib/kfac | |
parent | e202188ef741608384a3439fdbc0c4e2fb96e3f1 (diff) |
K-FAC: Utility function for scheduling N ops per global_step.
PiperOrigin-RevId: 181689879
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r-- | tensorflow/contrib/kfac/python/kernel_tests/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/kernel_tests/utils_test.py | 38 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/utils.py | 68 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/utils_lib.py | 1 |
5 files changed, 110 insertions, 0 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index 17458ffa2a..f4ed978174 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -122,6 +122,8 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:random_seed", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py index c8631ed89b..97a97adbf5 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -287,6 +288,43 @@ class UtilsTest(test.TestCase): tensor = array_ops.zeros([], dtype=dtypes.float32) mean = utils.cross_replica_mean(tensor) + def testBatchExecute(self): + """Ensure batch_execute runs in a round-robin fashion.""" + + def increment_var(var): + return lambda: var.assign_add(1) + + with ops.Graph().as_default(), self.test_session() as sess: + i = variable_scope.get_variable('i', initializer=0) + accumulators = [ + variable_scope.get_variable('var%d' % j, initializer=0) + for j in range(3) + ] + thunks = [increment_var(var) for var in accumulators] + increment_accumulators = utils.batch_execute(i, thunks, 2) + increment_i = i.assign_add(1) + + sess.run(variables.global_variables_initializer()) + + # Ensure one op per thunk. + self.assertEqual(3, len(increment_accumulators)) + + # Ensure round-robin execution. + values = [] + for _ in range(5): + sess.run(increment_accumulators) + sess.run(increment_i) + values.append(sess.run(accumulators)) + self.assertAllClose( + [ + [1, 1, 0], # + [2, 1, 1], # + [2, 2, 2], # + [3, 3, 2], # + [4, 3, 3] + ], + values) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index cd9dca3f02..ee6549b109 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -199,6 +199,7 @@ py_library( deps = [ "//tensorflow/contrib/tpu", "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:gradients", diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index d717f427e6..e89508fa46 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -25,6 +25,7 @@ from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops @@ -352,5 +353,72 @@ def ensure_sequence(obj): return (obj,) +def batch_execute(global_step, thunks, batch_size, name=None): + """Executes a subset of ops per global step. + + Given a list of thunks, each of which produces a single stateful op, + ensures that exactly 'batch_size' ops are run per global step. Ops are + scheduled in a round-robin fashion. For example, with 3 ops + + global_step | op0 | op1 | op2 + ------------+-----+-----+----- + 0 | x | x | + ------------+-----+-----+----- + 1 | x | | x + ------------+-----+-----+----- + 2 | | x | x + ------------+-----+-----+----- + 3 | x | x | + ------------+-----+-----+----- + 4 | x | | x + + Does not guarantee order of op execution within a single global step. + + Args: + global_step: Tensor indicating time. Determines which ops run. + thunks: List of thunks. Each thunk encapsulates one op. Return values are + ignored. + batch_size: int. Number of ops to execute per global_step. + name: string or None. Name scope for newly added ops. + + Returns: + List of ops. Exactly 'batch_size' ops are guaranteed to have an effect + every global step. + """ + + def true_fn(thunk): + """Ensures thunk is executed and returns an Op (not a Tensor).""" + + def result(): + with ops.control_dependencies([thunk()]): + return control_flow_ops.no_op() + + return result + + def false_fn(_): + """Executes a no-op.""" + + def result(): + return control_flow_ops.no_op() + + return result + + with ops.name_scope(name, "batch_execute"): + true_fns = [true_fn(thunk) for thunk in thunks] + false_fns = [false_fn(thunk) for thunk in thunks] + num_thunks = len(thunks) + conditions = [ + math_ops.less( + math_ops.mod(batch_size - 1 + global_step * batch_size - j, + num_thunks), batch_size) for j in range(num_thunks) + ] + result = [ + control_flow_ops.cond(condition, true_fn, false_fn) + for (condition, true_fn, + false_fn) in zip(conditions, true_fns, false_fns) + ] + return result + + # TODO(b/69623235): Add a function for finding tensors that share gradients # to eliminate redundant fisher factor computations. diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py index 074dc579da..cc48e3c69f 100644 --- a/tensorflow/contrib/kfac/python/ops/utils_lib.py +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -38,6 +38,7 @@ _allowed_symbols = [ "generate_random_signs", "fwd_gradients", "ensure_sequence", + "batch_execute", ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) |