about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-01-11 18:10:14 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-01-11 18:13:37 -0800
commit: 91526e1505e947fc64aece30ecfcd7ecec5de2c1 (patch)
tree: 0ed1618ec1e917f2cf8cf2b0a772c1b68ab146cb /tensorflow/contrib/kfac
parent: e202188ef741608384a3439fdbc0c4e2fb96e3f1 (diff)
K-FAC: Utility function for scheduling N ops per global_step.
PiperOrigin-RevId: 181689879
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r-- tensorflow/contrib/kfac/python/kernel_tests/BUILD | 2
-rw-r--r-- tensorflow/contrib/kfac/python/kernel_tests/utils_test.py | 38
-rw-r--r-- tensorflow/contrib/kfac/python/ops/BUILD | 1
-rw-r--r-- tensorflow/contrib/kfac/python/ops/utils.py | 68
-rw-r--r-- tensorflow/contrib/kfac/python/ops/utils_lib.py | 1
5 files changed, 110 insertions, 0 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
index 17458ffa2a..f4ed978174 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
@@ -122,6 +122,8 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:random_seed",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
index c8631ed89b..97a97adbf5 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -287,6 +288,43 @@ class UtilsTest(test.TestCase):
tensor = array_ops.zeros([], dtype=dtypes.float32)
mean = utils.cross_replica_mean(tensor)
+ def testBatchExecute(self):
+ """Ensure batch_execute runs 2 of 3 thunks per step, in round-robin order."""
+
+ def increment_var(var):
+ return lambda: var.assign_add(1)
+
+ with ops.Graph().as_default(), self.test_session() as sess:
+ i = variable_scope.get_variable('i', initializer=0)  # Used as global_step.
+ accumulators = [
+ variable_scope.get_variable('var%d' % j, initializer=0)
+ for j in range(3)
+ ]
+ thunks = [increment_var(var) for var in accumulators]
+ increment_accumulators = utils.batch_execute(i, thunks, 2)  # batch_size=2.
+ increment_i = i.assign_add(1)
+
+ sess.run(variables.global_variables_initializer())
+
+ # Ensure one op is returned per thunk (3 thunks were passed in).
+ self.assertEqual(3, len(increment_accumulators))
+
+ # Ensure round-robin execution: snapshot every counter after each step.
+ values = []
+ for _ in range(5):
+ sess.run(increment_accumulators)
+ sess.run(increment_i)
+ values.append(sess.run(accumulators))
+ self.assertAllClose(
+ [
+ [1, 1, 0], # After step 0: var0 and var1 were incremented.
+ [2, 1, 1], # After step 1: var0 and var2 were incremented.
+ [2, 2, 2], # After step 2: var1 and var2 were incremented.
+ [3, 3, 2], # After step 3: var0 and var1 were incremented.
+ [4, 3, 3] # After step 4: var0 and var2 were incremented.
+ ],
+ values)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
index cd9dca3f02..ee6549b109 100644
--- a/tensorflow/contrib/kfac/python/ops/BUILD
+++ b/tensorflow/contrib/kfac/python/ops/BUILD
@@ -199,6 +199,7 @@ py_library(
deps = [
"//tensorflow/contrib/tpu",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
index d717f427e6..e89508fa46 100644
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ b/tensorflow/contrib/kfac/python/ops/utils.py
@@ -25,6 +25,7 @@ from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
@@ -352,5 +353,72 @@ def ensure_sequence(obj):
return (obj,)
+def batch_execute(global_step, thunks, batch_size, name=None):
+ """Executes a subset of ops per global step, in round-robin order.
+
+ Each thunk produces a single stateful op. At global step t, the thunks at
+ indices t * batch_size, ..., t * batch_size + batch_size - 1 (taken modulo
+ len(thunks)) take effect. For example, with 3 ops and batch_size = 2:
+
+ global_step | op0 | op1 | op2
+ ------------+-----+-----+-----
+ 0 | x | x |
+ ------------+-----+-----+-----
+ 1 | x | | x
+ ------------+-----+-----+-----
+ 2 | | x | x
+ ------------+-----+-----+-----
+ 3 | x | x |
+ ------------+-----+-----+-----
+ 4 | x | | x
+
+ Does not guarantee order of op execution within a single global step.
+
+ Args:
+ global_step: Tensor indicating time. Determines which ops run.
+ thunks: List of zero-argument callables. Each builds one op and returns it;
+ the returned value is used only as a control dependency.
+ batch_size: int. Number of ops to execute per global_step.
+ name: string or None. Name scope for newly added ops.
+
+ Returns:
+ List of ops, one per thunk and in the same order. At each global step only
+ 'batch_size' of them have an effect (all, if batch_size >= len(thunks)).
+ """
+
+ def true_fn(thunk):
+ """Returns a cond branch that invokes thunk and returns an Op (not a Tensor)."""
+
+ def result():
+ with ops.control_dependencies([thunk()]):  # thunk() is invoked inside the branch.
+ return control_flow_ops.no_op()
+
+ return result
+
+ def false_fn(_):
+ """Returns a cond branch that only creates a no-op; the argument is unused."""
+
+ def result():
+ return control_flow_ops.no_op()
+
+ return result
+
+ with ops.name_scope(name, "batch_execute"):
+ true_fns = [true_fn(thunk) for thunk in thunks]
+ false_fns = [false_fn(thunk) for thunk in thunks]
+ num_thunks = len(thunks)
+ conditions = [  # conditions[j]: whether thunk j is scheduled at global_step.
+ math_ops.less(
+ math_ops.mod(batch_size - 1 + global_step * batch_size - j,
+ num_thunks), batch_size) for j in range(num_thunks)
+ ]  # Assumes mod yields a value in [0, num_thunks) even for a negative dividend.
+ result = [  # One conditional op per thunk, in the same order as thunks.
+ control_flow_ops.cond(condition, true_fn, false_fn)
+ for (condition, true_fn,
+ false_fn) in zip(conditions, true_fns, false_fns)
+ ]
+ return result
+
+
# TODO(b/69623235): Add a function for finding tensors that share gradients
# to eliminate redundant fisher factor computations.
diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py
index 074dc579da..cc48e3c69f 100644
--- a/tensorflow/contrib/kfac/python/ops/utils_lib.py
+++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py
@@ -38,6 +38,7 @@ _allowed_symbols = [
"generate_random_signs",
"fwd_gradients",
"ensure_sequence",
+ "batch_execute",
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)