path: root/tensorflow/python/kernel_tests/cond_v2_test.py
Commit message | Author | Age
* cond_v2: raise an error if pred is a Python bool. | Skye Wanderman-Milne | 2018-10-10
  This is to match the existing behavior of tf.cond. PiperOrigin-RevId: 216534084
* Partially support tfe.defun in tf.gradients. | Alexandre Passos | 2018-10-08
  Does not handle cases where the FunctionDef for the parent function may already have been generated, since in that case the forward pass cannot easily be modified. PiperOrigin-RevId: 216243224
* Copy device from If op to the lowered ops. | Saurabh Saxena | 2018-10-05
  Enable GPU tests for cond_v2. PiperOrigin-RevId: 215956220
* Unpack the output of cond_v2 if it is a singleton, to match the behavior of cond. | Saurabh Saxena | 2018-09-24
  PiperOrigin-RevId: 214381126
* Replace self.test_session(graph=<an object not None>) with self.session(graph=...), as they have the same semantics. | A. Unique TensorFlower | 2018-09-24
  PiperOrigin-RevId: 214286845
* Change test to use 2 CPU devices instead of GPU. | A. Unique TensorFlower | 2018-09-13
  General cleanup: testDeviceInAndOutOfCond used a GPU, but in a CPU-only test build this resulted in all operations running on the same device even though the graph targets multiple devices. PiperOrigin-RevId: 212775360
* Move from deprecated self.test_session() to self.cached_session(). | A. Unique TensorFlower | 2018-09-13
  self.test_session() was deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 because its name confuses readers of the test. Moving to cached_session() instead, which is more explicit about:
  - the fact that the session may be reused;
  - the fact that the session is not closed even when used in a "with self.test_session()" statement.
  PiperOrigin-RevId: 212766976
* This change re-enables a few cond_v2 tests now that the underlying defun issues have been fixed. | A. Unique TensorFlower | 2018-09-12
  PiperOrigin-RevId: 212733064
* Switch cond_v2 to using tfe.defun instead of function.Defun. | Skye Wanderman-Milne | 2018-08-23
  This requires a few changes:
  - Make func_graph_from_py_func public (though it is not part of the official public API).
  - Make function_def_to_graph return the new FuncGraph implementation.
  - Disable some cond_v2 tests until they work with the new FuncGraph implementation.
  - Add an outer_graph field to FuncGraph.
  - Add external_captures and internal_captures properties to FuncGraph for readability.
  - Remove the extra_inputs/extra_args terminology from cond_v2_impl for readability.
  - Use compat.as_str() around Graph._functions keys. In Python 3, we were somehow getting a mix of str and bytes objects.
  PiperOrigin-RevId: 210015940
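The last item addresses a Python 3 str/bytes mismatch in dictionary keys. A minimal sketch of the idea, assuming a hypothetical stand-in `as_str` that decodes `bytes` as UTF-8 (roughly what `tf.compat.as_str` does under Python 3) and a plain dict standing in for `Graph._functions`; `add_function` is also hypothetical.

```python
def as_str(name, encoding="utf-8"):
    """Hypothetical stand-in: decode bytes to str and pass str through unchanged."""
    if isinstance(name, bytes):
        return name.decode(encoding)
    if isinstance(name, str):
        return name
    raise TypeError("Expected str or bytes, got %r" % (name,))


functions = {}  # stands in for Graph._functions


def add_function(name, fdef):
    # Normalize the key so b"f" and "f" refer to the same dictionary entry.
    functions[as_str(name)] = fdef


add_function(b"cond_true_1", "fdef_a")
add_function("cond_true_1", "fdef_b")  # overwrites the same entry, no duplicate key
print(list(functions))  # ['cond_true_1']
```

Without the normalization, `b"cond_true_1"` and `"cond_true_1"` would be two distinct keys in Python 3, so lookups could silently miss a registered function.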
* Use outer_graph to uniquify the names of the then/else fns. | Jacques Pienaar | 2018-08-10
  PiperOrigin-RevId: 208229302
* Remove identity ops for ys added during gradient computation. | Saurabh Saxena | 2018-08-08
  These were added to avoid issues when computing gradients of dependent ys. The underlying issue has since been fixed, so the identity ops are no longer needed. PiperOrigin-RevId: 207974344
* Support Defuns and nested Defuns inside cond_v2 branches. | Saurabh Saxena | 2018-07-19
  Support nested cond_v2s. PiperOrigin-RevId: 205356562
* Enables `If` operator lowering in cond_v2 when XLA is disabled. | A. Unique TensorFlower | 2018-06-18
  Lowering allows cond_v2 to avoid some of the limitations of Functions: users can specify devices and colocation inside cond_v2 branches, and branches get non-strict evaluation and partial pruning. This brings cond_v2 closer to feature parity with tf.cond. However, we do not lower `If` in the XLA context, because it is easier for XLA to apply its own optimizations to un-lowered `If` operators than to lowered switch/merge control flow. Also adds a toggleable flag for InlineFunctionBody in function.cc that prevents the function caller device from overriding the devices of function body nodes. This is necessary for cond_v2 branches to support explicitly specified devices. Adds several tests to make sure that:
  - lowering is usually enabled
  - lowering is disabled for XLA
  - node colocation inside cond_v2 branches works
  - explicit device placement inside cond_v2 branches works
  PiperOrigin-RevId: 201049850
* Move cond_v2 to core (non-public) and add a toggle to use cond_v2 by default. | Skye Wanderman-Milne | 2018-06-15
  This change:
  - Creates a new global variable, control_flow_ops._ENABLE_COND_V2, to use cond_v2 by default when calling tf.cond. This variable can also be controlled via the environment variable TF_ENABLE_COND_V2.
  - Moves cond_v2 out of contrib so it is accessible from control_flow_ops.py.
  - Lazily "imports" some modules in cond_v2 to avoid circular dependencies. Note that these lazily "imported" modules must be imported by the cond_v2 caller (or recursively by one of the caller's imports) for cond_v2 to have access to them.
  - Renames the cond_v2 module to cond_v2_impl, and creates a new cond_v2 module that imports the cond_v2 method and the necessary extra imports. This is useful for explicitly calling cond_v2 outside of control_flow_ops.cond.
  PiperOrigin-RevId: 200778208