| Commit message | Author | Age |
This is to match the existing behavior of tf.cond.
PiperOrigin-RevId: 216534084
Doesn't attempt to deal with cases where we might have already generated
the FunctionDef for the parent function, as in that case we cannot easily
modify the forward pass.
PiperOrigin-RevId: 216243224
Enable GPU tests for cond_v2.
PiperOrigin-RevId: 215956220
PiperOrigin-RevId: 214381126
self.session(graph=...) as it has the same semantics.
PiperOrigin-RevId: 214286845
General cleanup: testDeviceInAndOutOfCond uses a GPU in a CPU-only test build, resulting in all operations running on the same device even though the graph targets multiple devices.
PiperOrigin-RevId: 212775360
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 because its name confuses readers of the test. Moving to cached_session() instead, which is more explicit about:
* the fact that the session may be reused;
* the fact that the session is not closed even when used in a "with self.test_session()" statement.
PiperOrigin-RevId: 212766976
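The two properties listed above can be illustrated with a minimal pure-Python sketch. It is only an approximation of the semantics, not TensorFlow's implementation: `FakeSession` and `CachedSessionTestCase` are hypothetical names. The session is created once, the same object is returned on every `with` block, and it is closed only at teardown.

```python
import contextlib
from typing import Iterator, Optional


class FakeSession:
    """Hypothetical stand-in for a TensorFlow session."""

    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        self.closed = True


class CachedSessionTestCase:
    """Hypothetical sketch of cached_session() semantics (not TF code)."""

    def __init__(self) -> None:
        self._cached: Optional[FakeSession] = None

    @contextlib.contextmanager
    def cached_session(self) -> Iterator[FakeSession]:
        # Create the session on first use; hand back the same object afterwards.
        if self._cached is None:
            self._cached = FakeSession()
        # Deliberately no close() when the `with` block exits.
        yield self._cached

    def tearDown(self) -> None:
        # The session is closed only when the test case is torn down.
        if self._cached is not None:
            self._cached.close()
            self._cached = None
```

Usage: two consecutive `with tc.cached_session() as sess:` blocks receive the same object, and it is still open after each block exits.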
issues have been fixed.
PiperOrigin-RevId: 212733064
This requires a few changes:
- Make func_graph_from_py_func public (not part of the official public API though)
- Make function_def_to_graph return the new FuncGraph implementation.
- Disable some cond_v2 tests until we get them working with the new FuncGraph implementation.
- Add outer_graph field to FuncGraph.
- Add external_captures and internal_captures properties to FuncGraph for readability.
- Remove extra_inputs/extra_args terminology from cond_v2_impl for readability.
- Use compat.as_str() around Graph._functions keys. In Python 3, we were somehow getting a mix of str and bytes objects.
PiperOrigin-RevId: 210015940
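The str/bytes issue in the last bullet can be shown with a small sketch. It assumes a hypothetical `FunctionRegistry` standing in for `Graph._functions`, plus an `as_str` helper modeled on what `compat.as_str` does under Python 3: decode bytes as UTF-8 and pass str through unchanged. Without this normalization, `"f"` and `b"f"` would be stored as two distinct dict keys.

```python
from typing import Dict, Optional, Union


def as_str(name: Union[str, bytes], encoding: str = "utf-8") -> str:
    # Illustrative stand-in for compat.as_str: decode bytes, pass str through.
    if isinstance(name, bytes):
        return name.decode(encoding)
    if isinstance(name, str):
        return name
    raise TypeError(f"Expected str or bytes, got {type(name).__name__}")


class FunctionRegistry:
    """Hypothetical registry standing in for Graph._functions."""

    def __init__(self) -> None:
        self._functions: Dict[str, object] = {}

    def add(self, name: Union[str, bytes], fdef: object) -> None:
        # Normalizing the key means b"f" and "f" refer to the same entry.
        self._functions[as_str(name)] = fdef

    def get(self, name: Union[str, bytes]) -> Optional[object]:
        return self._functions.get(as_str(name))
```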
PiperOrigin-RevId: 208229302
to avoid issues with computing gradients when ys were dependent.
However, the underlying issue has since been fixed, so adding identity ops is no longer necessary.
PiperOrigin-RevId: 207974344
Support nested cond_v2s.
PiperOrigin-RevId: 205356562
allows cond_v2 to avoid some of the limitations of Functions. Users can specify devices and colocation inside cond_v2 branches, and branches support non-strict evaluation and partial pruning. This brings cond_v2 closer to feature parity with tf.cond.
However, we do not lower `If` in the XLA context because it is easier for XLA to apply its own optimizations when dealing with un-lowered `If` operators than with lowered switch/merge control flow.
Also adds a toggleable flag to InlineFunctionBody in function.cc that prevents the function caller's device from overriding the devices of function body nodes. This is necessary for cond_v2 branches to support explicitly specified devices.
Adds several tests to make sure that:
- lowering is usually enabled
- lowering is disabled for XLA
- node colocation inside of cond_v2 branches works
- explicit device placement inside of cond_v2 branches works
PiperOrigin-RevId: 201049850
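Here is a rough Python sketch of the device-override toggle described above (the real code is C++ in function.cc). `Node`, `inline_function_body`, and `override_devices` are hypothetical names and do not reflect the actual signature. The rule that body nodes with no explicit device inherit the caller's device is an assumption made for illustration.

```python
from dataclasses import dataclass, replace
from typing import List


@dataclass(frozen=True)
class Node:
    """Hypothetical body node; device == "" means no device was requested."""
    name: str
    device: str = ""


def inline_function_body(body: List[Node], caller_device: str,
                         override_devices: bool = True) -> List[Node]:
    inlined = []
    for node in body:
        if override_devices or not node.device:
            # Caller's device wins (or fills in an unspecified device; assumed).
            inlined.append(replace(node, device=caller_device))
        else:
            # Override disabled: keep the device requested inside the branch.
            inlined.append(node)
    return inlined
```

With `override_devices=False`, a node pinned to a GPU inside a branch keeps that placement after inlining. That is the behavior explicit device placement in cond_v2 branches relies on.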
This change:
* Creates a new global variable, control_flow_ops._ENABLE_COND_V2, to use
cond_v2 by default when calling tf.cond. This variable can also be
controlled via the environment variable TF_ENABLE_COND_V2.
* Moves cond_v2 out of contrib so it's accessible from control_flow_ops.py.
* Lazily "imports" some modules in cond_v2 to avoid circular dependencies.
Note that these lazy "imports" must be imported by the cond_v2 caller (or
recursively by one of the caller's imports) in order for cond_v2 to have
access to them.
* Renames the cond_v2 module to cond_v2_impl, and creates a new cond_v2 module
that imports the cond_v2 method and the necessary extra imports. This is
useful for explicitly calling cond_v2 outside of control_flow_ops.cond.
PiperOrigin-RevId: 200778208