path: root/tensorflow/compiler/xla/service/hlo_computation.cc
Commit message · Author · Age
...
* [XLA] Remove the notion of a "parameter name" separate from the instruction's name. (Justin Lebar, 2017-12-14)
  Also set the instruction's name in the HLO parser, so that after parsing, the instructions have the names they're given in the input string.
  PiperOrigin-RevId: 179119003
* [XLA] Gather the bool parameters into a single struct to control the text format. (A. Unique TensorFlower, 2017-12-14)
  PiperOrigin-RevId: 179079727
* Allow removing side-effecting instructions from an HLO computation. (HyoukJoong Lee, 2017-11-28)
  Moved the removability condition into the hlo_dce pass.
  PiperOrigin-RevId: 177215395
* When constructing fusion computations from a proto, do not uniquify the names. (Mark Heffernan, 2017-11-16)
  The names are already unique, and uniquifying them again would mutate them, resulting in inconsistent names between the proto and the constructed HLO.
  PiperOrigin-RevId: 176035108
* HLO parser: support fusion. (A. Unique TensorFlower, 2017-11-13)
  Also:
  - Add an HloInstruction::CreateFusion interface that creates a fusion instruction with a given fusion computation, and an HloComputation::SetFusionInstruction interface to help do that.
  - Change how the fusion kind is printed. Previously it was printed together with the opcode (e.g., fusion:kLoop), which is hard to parse; it is now appended as an attribute.
  - Print the fusion computation the same way as other computations, instead of nested inside an instruction.
  PiperOrigin-RevId: 175621768
* Roll back the copy-insertion change because it triggers a DCHECK failure with an internal model. (Mark Heffernan, 2017-11-03)
  Automated g4 rollback of changelist 174423881.
  PiperOrigin-RevId: 174505237
* [TF:XLA] Improve support for const HLO visitors. (A. Unique TensorFlower, 2017-11-03)
  Add missing const overloads of the Accept methods.
  PiperOrigin-RevId: 174500495
* Rewrite CopyInsertion to use module-scoped HloAliasAnalysis. (Mark Heffernan, 2017-11-02)
  The net effect (number of copies inserted) is roughly similar to the existing implementation, but the new implementation is much more general. For example, it can handle entry-argument buffer reuse with minimal modification. Some unnecessary copies are still added due to deficiencies in buffer assignment (b/62548313), but these can be removed once buffer assignment also uses HloAliasAnalysis.
  This CL also addresses a few issues it uncovered:
  (1) For in-place dynamic slice in the LLVM backends, truncate rather than wrap the slice. This matches the behavior of the non-in-place variant.
  (2) Disable the SelectBetweenPredTuples test on GPU. The test introduces top-level buffer ambiguity, which the GPU backend does not tolerate.
  (3) When deserializing HLO from a proto, do not uniquify instruction names in fused computations.
  (4) In dataflow analysis, don't deallocate deleted HloValues during propagation.
  (5) In dataflow analysis, fix an issue with the live_out_of_computation property.
  PiperOrigin-RevId: 174423881
* [XLA] Add dead tuple element removal to WhileLoopSimplifier. (Justin Lebar, 2017-11-02)
  Specifically, if a while loop has a tuple element that
  - is not used by the while condition, and
  - is not used by the while body, except to pass it along to the next iteration of the loop,
  then we can reshape the while loop's computations to eliminate this tuple element.
  PiperOrigin-RevId: 174413683
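  The dead-element test above can be sketched as follows. This is an illustrative model, not XLA's implementation: the loop state is a tuple, the condition is summarized by the set of element indices it reads, and the body by which input elements each output element reads plus which outputs are pure pass-throughs.

  ```python
  def dead_tuple_elements(num_elements, cond_reads, body_reads, passthrough):
      """Return tuple-element indices removable from a while loop.

      cond_reads:  set of element indices read by the loop condition.
      body_reads:  dict mapping output index -> set of input indices it reads.
      passthrough: set of output indices that simply forward input[i] -> output[i].
      """
      dead = set()
      for i in range(num_elements):
          if i in cond_reads:
              continue                   # the condition needs this element
          if i not in passthrough:
              continue                   # the body computes something new for it
          # Is the element used by the body only to pass itself along?
          used_elsewhere = any(
              i in reads for j, reads in body_reads.items() if j != i)
          if not used_elsewhere:
              dead.add(i)
      return dead
  ```

  For example, an induction counter read by the condition stays, while a value that is merely threaded through every iteration untouched is removed.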
* [TF:XLA] Add a const HLO visitor. (A. Unique TensorFlower, 2017-11-02)
  Use it in the HLO cost analysis pass.
  PiperOrigin-RevId: 174411043
* Fix HloComputation/HloInstruction cloning to allow deep clones, and to prevent cloned instructions and computations from retaining live links to their original parent modules and computations. (A. Unique TensorFlower, 2017-11-01)
  PiperOrigin-RevId: 174271432
* [XLA] Allow full dumps of constant values via a boolean parameter. (Chris Leary, 2017-11-01)
  PiperOrigin-RevId: 174257660
* Sharding support for HLO. (A. Unique TensorFlower, 2017-10-30)
  Supported in this CL:
  * Attaching sharding descriptors to HLO ops.
  * Partitioning the HLO graph into per-device computations based on those sharding descriptors.
  * All-operator support for device placement and for ops replicated on all devices.
  * Elementwise op support for tiled shardings.
  * 2D convolution support for tiled shardings (no stride or dilation support).
  PiperOrigin-RevId: 173946036
* Support more instructions in the HLO parser. (A. Unique TensorFlower, 2017-10-24)
  - while, tuple, send/recv, get-tuple-element, call.
  - "device=" attribute.
  Also:
  - Change HloModule::ToString to print computations in post order, so that a computation is defined before it is used.
  - Add % before a computation's name when it is used.
  PiperOrigin-RevId: 173350323
* Add a recursive descent parser for the HloModule string. (A. Unique TensorFlower, 2017-10-19)
  It constructs an HloModule object from a string printed by HloModule::ToString(). This is an initial stage; it currently supports:
  - unary, binary, and ternary ops, and other ops that don't have extra attributes;
  - modules with an entry computation only;
  - simple cases of the constant instruction.
  To make the parser simpler, this CL removes a whitespace and adds a '%' before the computation name in HloComputation::ToString(). Further steps will enable parsing subcomputations, more cases of constants, tuples, and ops that require extra attributes (e.g., broadcast dimensions, subcomputation).
  PiperOrigin-RevId: 172804214
* `tf.py_func`: Handle NumPy arrays of np.object that hold Unicode strings. (Derek Murray, 2017-10-19)
  This also fixes a bug affecting `tf.data.Dataset.from_generator()` on Python 3, where the generator yields Unicode (i.e., default) strings.
  PiperOrigin-RevId: 172798007
* Make the HLO proto representation (hlo.proto) full fidelity. (Mark Heffernan, 2017-10-13)
  HLO modules can now be serialized to HLO protos and deserialized without any information loss.
  As part of this change, a bug is fixed in NameUniquer: previously, passing names with numeric suffixes could result in name collisions.
  PiperOrigin-RevId: 172161360
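  The NameUniquer bug is easy to reproduce: if "param.1" is registered first and "param" is then uniquified twice, a naive implementation emits "param" followed by "param.1", a collision. A hedged sketch of the fix (illustrative Python, not XLA's C++ NameUniquer) is to fold numeric-suffixed names into the same family as their root before probing for a free suffix:

  ```python
  class NameUniquer:
      def __init__(self, separator="."):
          self.sep = separator
          self.generated = set()     # every name handed out so far

      def get_unique_name(self, name):
          if name not in self.generated:
              self.generated.add(name)
              return name
          # Strip any numeric suffix so "foo.2" collides into the "foo" family.
          base, sep, suffix = name.rpartition(self.sep)
          root = base if (sep and suffix.isdigit()) else name
          i = 1
          while f"{root}{self.sep}{i}" in self.generated:
              i += 1
          unique = f"{root}{self.sep}{i}"
          self.generated.add(unique)
          return unique
  ```

  Because previously issued suffixed names live in the same set that is probed, re-uniquifying "param" skips past the already-taken "param.1".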
* [XLA] Make it possible to inline calls to side-effecting computations. (Chris Leary, 2017-09-29)
  PiperOrigin-RevId: 170515496
* [XLA] Make HloComputation::instructions() return a view of HloInstruction*s. (Justin Lebar, 2017-09-28)
  Currently it returns a view of unique_ptr<HloInstruction>s, but the fact that these are unique_ptrs is an implementation detail, and it's ugly to leak it everywhere.
  PiperOrigin-RevId: 170445375
* [XLA] Replace HloComputation::ReplaceUsesOfInstruction with HloInstruction::ReplaceAllUsesWith. (Justin Lebar, 2017-09-27)
  RAUW used to be *almost* synonymous with RUOI, except that RAUW didn't update the computation's root. This was a dangerous footgun: if you accidentally called RAUW when you wanted RUOI (which you almost always did), your code would work perfectly, except when the relevant node happened to be the root of a computation. This change simplifies our APIs so there's just one Right Way To Do It, by making RAUW update the computation's root.
  PiperOrigin-RevId: 170290230
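  The footgun can be shown with a toy graph (hypothetical minimal types, not XLA's): replacing every use of an instruction is not enough, because a computation's root may not be "used" by anyone and so is missed unless the replacement also checks it explicitly:

  ```python
  class Instr:
      def __init__(self, name, operands=()):
          self.name, self.operands, self.users = name, list(operands), []
          for op in self.operands:
              op.users.append(self)

  class Computation:
      def __init__(self, root):
          self.root = root

      def replace_all_uses_with(self, old, new):
          # Rewire every user's operand list from old to new.
          for user in old.users:
              user.operands = [new if o is old else o for o in user.operands]
              new.users.append(user)
          old.users = []
          # The step the old root-unaware RAUW missed:
          if self.root is old:
              self.root = new
  ```

  Without the final root check, a replacement at the root would leave the computation still producing the old instruction's value.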
* [XLA] Refactor the parent-fusion-instruction pointer into HloComputation, not HloInstruction. (A. Unique TensorFlower, 2017-08-30)
  Presently, each instruction inside a fusion computation contains a pointer to the fusion instruction that contains the computation. This is redundant, since the pointer is common across the entire computation. It leads to many places where the pointer must be set when adding an instruction to the fusion computation (and to bugs such as b/65177535 when one is missed), as well as code to check that it is set correctly; it is also simply unnecessary data bloat. Moreover, the computation itself does not contain a pointer to the fusion instruction that references it, which leads to odd circumlocutions in the HloComputation code that retrieve the fusion instruction from the computation's root instruction.
  This CL therefore moves the pointer into the HloComputation class (replacing the is_fusion_computation_ bool) and refactors the uses as necessary.
  PiperOrigin-RevId: 167039280
* [XLA] Teach HloComputation::Reparent to properly handle reparenting into fusion computations. (A. Unique TensorFlower, 2017-08-30)
  This also moves HloInstruction::CheckFusionInstruction() out of "private", and adds calls to it in the reduce-precision-insertion test to confirm that the reduce-precision-insertion pass maintains valid fusion computations. (These checks fail without the fix to HloComputation::Reparent.)
  PiperOrigin-RevId: 167031741
* Add an option to HloComputation::DeepCopyInstruction for selectively copying only certain indices. (Mark Heffernan, 2017-08-25)
  Also add a mechanism for returning the kCopy instructions added to create the deep copy.
  PiperOrigin-RevId: 166521917
* Merge changes from github. (Jonathan Hseu, 2017-08-25)
  --- Commit b30ce4714, authored by James Qin <jamesqin@google.com>, committed by TensorFlower Gardener: Revamp CudnnRNN Saveables.
      1. Use a lossy way to save/restore cudnn biases during checkpointing. Cudnn uses two biases per gate for all RNNs, while TF uses one. To keep cudnn checkpoints compatible with both Cudnn and platform-independent implementations, previously both the individual biases and the summed biases per gate were stored. The new way stores only the bias sum for each gate, and splits it half-and-half when restoring into a cudnn graph. This causes no problems, since RNNs do not use weight decay for regularization.
      2. Use inheritance instead of branching: split RNNParamsSaveable into one base class and four subclasses, extracting common routines and overriding only the rnn-type-specific pieces in the subclasses.
      PiperOrigin-RevId: 166413989
  --- Commit ebc421daf, authored by Alan Yee <alyee@ucsd.edu>, committed by Jonathan Hseu <vomjom@vomjom.net>: Update documentation for contrib (#12424). Remove ## for standardization of api docs; add a README defining the directory's purpose and mentioning what to deprecate; assorted markdown touch-ups.
  --- Commit fd295394d, authored by A. Unique TensorFlower, committed by TensorFlower Gardener: Use the latest version of the nsync library, which now allows use of cmake on MacOS. PiperOrigin-RevId: 166411437
  --- Commit 587d728e0, authored by A. Unique TensorFlower, committed by TensorFlower Gardener: [XLA] Refactor reduce-precision-insertion filters; add several more options. In particular, this adds the ability to add reduce-precision operations after fusion nodes based on the contents of those fusion nodes, and the ability to filter operations based on the "op_name" metadata. PiperOrigin-RevId: 166408392
  --- Commit 3142f8ef5, authored by Ali Yahya <alive@google.com>, committed by TensorFlower Gardener: Steps toward making ResourceVariables compatible with Eager. This change forces the value of the reuse flag in variable scopes to be tf.AUTO_REUSE when in Eager mode, and adds comprehensive Eager tests for ResourceVariable. PiperOrigin-RevId: 166408161
  --- Commit b2ce45150, authored by Igor Ganichev <iga@google.com>, committed by TensorFlower Gardener: Make Graph::IsValidNode public. It can be reimplemented with existing public APIs, but making this one public seems better. PiperOrigin-RevId: 166407897
  --- Commit 0a2f40e92, authored by A. Unique TensorFlower, committed by TensorFlower Gardener: [XLA::CPU] Fix HLO profiling in the parallel CPU backend. PiperOrigin-RevId: 166400211
  --- Commit c4a58e3fd, authored by Yao Zhang <yaozhang@google.com>, committed by TensorFlower Gardener: Identify frame ids for all nodes in a graph. PiperOrigin-RevId: 166397615
  --- Commit 989713f26, authored by A. Unique TensorFlower, committed by TensorFlower Gardener: Automated g4 rollback of changelist 166294015.
  PiperOrigin-RevId: 166521502
* Merge sibling fusion instructions using multi_output_fusion. (A. Unique TensorFlower, 2017-08-10)
  PiperOrigin-RevId: 164920220
* Consider nested computations when checking whether an instruction is removable from a computation. (HyoukJoong Lee, 2017-08-09)
  This prevents DCE from removing a while instruction that includes a send/recv instruction.
  PiperOrigin-RevId: 164722478
* Assign unique ids at the HloModule level to each HloInstruction object. (Jeffrey A. Dean, 2017-08-02)
  Use these ids when doing DFS over a graph to store the visited bits in an array of two-bit values (in the dfs_hlo_visitor.{h,cc} module), rather than in a significantly larger and more expensive hash table. Ids are initially -1 and are assigned when unique names are assigned to the HloInstruction objects.
  Speeds up compilation of a convolutional image model by ~5.3%.
  PiperOrigin-RevId: 164050902
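  The two-bit visit-state array works because a DFS only needs three states per node, which fit in two bits indexed by the module-level unique id. A sketch under assumed state names (the real visitor packs these in C++; a bytearray is enough to show the idea):

  ```python
  NOT_VISITED, VISITING, VISITED = 0, 1, 2

  class VisitMap:
      """Two bits of visit state per instruction id, packed four per byte."""

      def __init__(self, num_ids):
          self.bits = bytearray((2 * num_ids + 7) // 8)

      def get(self, uid):
          byte, shift = divmod(2 * uid, 8)
          return (self.bits[byte] >> shift) & 0b11

      def set(self, uid, state):
          byte, shift = divmod(2 * uid, 8)
          self.bits[byte] = (self.bits[byte] & ~(0b11 << shift)) | (state << shift)
  ```

  For a module with N instructions this costs N/4 bytes, versus a hash table entry per visited node, and lookup is two arithmetic operations plus a byte load.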
* [XLA:CPU] Support for CPU outfeed and an xfeed (infeed/outfeed) test. (A. Unique TensorFlower, 2017-07-11)
  Note: does not yet support nested tuples, for symmetry with the current infeed limitations.
  PiperOrigin-RevId: 161502502
* [XLA] Add reasonable error messages to Builder::Build for bad parameter numbers. (A. Unique TensorFlower, 2017-07-06)
  PiperOrigin-RevId: 161136262
* Make HloInstruction names unique within an HloModule. (A. Unique TensorFlower, 2017-07-04)
  This will allow most uses of HloInstruction::FullyQualifiedName() to be replaced with HloInstruction::Name().
  PiperOrigin-RevId: 160879036
* [XLA] Move HLO reachability into its own file and make it updatable. (Mark Heffernan, 2017-06-29)
  As part of this CL, change the underlying representation of the reachability map to bit vectors, which allows efficient updates by OR'ing the vectors together.
  PiperOrigin-RevId: 160591849
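  The bit-vector scheme above can be sketched with Python ints standing in for the bit vectors: each node's reachability set is the OR of its predecessors' sets plus the predecessors themselves, computed in producer-before-consumer order.

  ```python
  def compute_reachability(order, inputs):
      """order:  node ids in topological (producer-before-consumer) order.
      inputs: dict node -> list of predecessor nodes (operands and control
              predecessors).
      Returns a dict node -> bitmask of nodes that reach it (including itself).
      """
      reach = {}
      for node in order:
          mask = 1 << node                  # a node reaches itself
          for pred in inputs.get(node, []):
              mask |= reach[pred]           # OR in everything reaching the pred
          reach[node] = mask
      return reach
  ```

  Adding an edge later only requires OR'ing the source's vector into the target's (and propagating to the target's transitive consumers), which is why the representation makes the map cheap to update.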
* [XLA] Several fixes to HLO reachability analysis. (Mark Heffernan, 2017-06-28)
  (1) Account for control dependencies in reachability.
  (2) Invert the sense of reachability. We draw our HLO graphs with arrows from producers to consumers, so it makes more sense for reachability to be defined along the direction of these edges.
  (3) Rename ComputeTransitiveOperands to ComputeReachability.
  PiperOrigin-RevId: 160366307
* Minor cleanup: add braces around if-statement arms; remove redundant "return" and "static". (A. Unique TensorFlower, 2017-06-06)
  PiperOrigin-RevId: 158143418
* Minor cleanup: remove unused BUILD dependencies and unnecessary code. (A. Unique TensorFlower, 2017-06-02)
  PiperOrigin-RevId: 157837211
* [TF:XLA] Preserve metadata when replacing HLO instructions. (Eric Liu, 2017-05-30)
  The motivation is to add metadata to HLO instructions that are created to replace existing HLO instructions during optimizations. The assumption is that the old instruction and the new instruction perform the same function and correlate to the same TF op. This might not always be correct, since HLO optimizations can cross TF op boundaries, but it still seems better than nothing. Note that this does not fully resolve missing OpMetadata after HLO optimizations; new instructions might be added without using ReplaceInstruction.
  PiperOrigin-RevId: 157484394
* Add debug protos that serialize HLO graph information. (A. Unique TensorFlower, 2017-05-25)
  Also add flags to dump this data in JSON format for each backend. This is useful for upcoming debugging tools.
  PiperOrigin-RevId: 157178357
* [XLA] Avoid accumulating '%' in front of fusion parameter names. (A. Unique TensorFlower, 2017-05-17)
  PiperOrigin-RevId: 156357948
* [XLA] Various HLO naming fixes. (Mark Heffernan, 2017-04-27)
  This change includes a number of fixes to HLO instruction naming, especially for fusion instructions. Specific changes:
  (1) Remove HloInstruction::set_name and HloComputation::set_name. These methods were a bit dangerous, as they made it easy to create non-unique HLO names. Replace them with UniquifyName, which renames the object to a unique name based on its current name.
  (2) Name fusion computations "fused_computation". Previously they were named after the root.
  (3) Change the naming of fusion parameters. They are now named after the unfused instructions whose values they represent. Also, previously superfluous ".1", ".2", etc. could be appended to parameter names; this change fixes that.
  (4) Name instructions in fusion computations identically to the instructions they were cloned from. Previously all fused instructions would end up with a ".clone" suffix.
  (5) If HloInstruction::Clone() is called with an empty suffix, don't add a "." to the name.
  Change: 154454938
* [XLA:HLO] Don't remove HLO instructions with control dependencies in DCE. (Mark Heffernan, 2017-04-25)
  Removing instructions with control dependencies may violate the ordering constraint that the control dependency is intended to enforce.
  Change: 154190963
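  Pulling together the removability rules from the DCE-related commits in this log, an unused instruction is removable only if it is not the root, has no users, has no control dependencies, and is not side-effecting (e.g. send/recv). A hedged sketch on a toy representation (illustrative dicts, not XLA's classes):

  ```python
  def dce(instructions, root):
      """instructions: list of dicts with keys 'name', 'users',
      'control_deps', and 'side_effecting'.  Returns the surviving names."""
      live = {i["name"]: i for i in instructions}
      removed = True
      while removed:                          # iterate to a fixed point
          removed = False
          for name, instr in list(live.items()):
              if (name != root and not instr["users"]
                      and not instr["control_deps"]
                      and not instr["side_effecting"]):
                  # Drop the instruction and detach it from its operands' users.
                  del live[name]
                  for other in live.values():
                      other["users"] = [u for u in other["users"] if u != name]
                  removed = True
      return set(live)
  ```

  The fixed-point loop matters: removing one dead instruction can make its operands user-less and therefore dead in turn.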
* [XLA] Fix the parameter instruction printing issue. (A. Unique TensorFlower, 2017-04-20)
  Append the parameter number to the fusion parameter name, and use the parameter name rather than the instruction name when creating the new parameter. Show the parameter number when printing parameter instructions.
  Change: 153752424
* [XLA] Fix the indentation for printing fusion computations. (A. Unique TensorFlower, 2017-04-18)
  Change: 153502127
* [XLA] Represent fusion instructions as an HloComputation. (A. Unique TensorFlower, 2017-04-17)
  Use an HloComputation to represent the HloInstructions inside a fusion instruction. All interfaces are kept the same except for the parent field of the fused instructions: it now points to the newly created HloComputation rather than to the enclosing computation of the fusion instruction.
  Change: 153390245
* [XLA] Flatten the computation call graph. (A. Unique TensorFlower, 2017-04-14)
  This CL clones computations that are called from more than one call site in a sequential context (call and while nodes) so that the call graph becomes a tree.
  Change: 153183115
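  The flattening idea can be sketched on a toy call graph (hypothetical minimal representation, not XLA's CallGraph): walk the graph, and whenever a callee has already been reached from another call site, substitute a fresh clone so every computation ends up with exactly one caller.

  ```python
  import copy
  import itertools

  _clone_ids = itertools.count(1)

  def flatten(comp, seen=None):
      """comp: dict {'name': str, 'calls': [comp, ...]}.
      Clones any callee reached from a second call site, making the call
      graph a tree."""
      if seen is None:
          seen = set()
      new_calls = []
      for callee in comp["calls"]:
          if id(callee) in seen:             # second call site: clone it
              callee = copy.deepcopy(callee)
              callee["name"] += f".clone.{next(_clone_ids)}"
          seen.add(id(callee))
          new_calls.append(flatten(callee, seen))
      comp["calls"] = new_calls
      return comp
  ```

  After flattening, passes that assume a single sequential caller per computation (e.g. for ordering or buffer assignment) can treat each computation independently.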
* [XLA] Fix a compiler scoping error breaking the OSS build. (A. Unique TensorFlower, 2017-03-28)
  Change: 151518190
* [XLA] Add an HLO verifier that checks each HLO instruction's parent computation. (A. Unique TensorFlower, 2017-03-28)
  Change: 151494158
* Improve the fused-instruction dump. (A. Unique TensorFlower, 2017-03-17)
  Change: 150460833
* [XLA] Replace uses of std::set with std::vector. (Mark Heffernan, 2017-03-10)
  std::set is slow and its iteration order is unstable. A couple of other opportunistic changes: consolidate all called computations of an instruction into a single vector, which facilitates fast access to all called computations; and replace AddControlSuccessor/Predecessor with AddControlDependencyTo/RemoveControlDependencyTo, which is less error-prone because it cannot create a half-connected control edge.
  Change: 149810889
* [TF:XLA] Reduce sequential memory usage via better ordering and a simulated heap. (A. Unique TensorFlower, 2017-03-02)
  The choice of instruction ordering, and the minimization of fragmentation once an order is chosen, are two large interrelated factors in overall memory usage. The approach in this CL uses heuristics to do better on both, but neither problem is completely solved.
  To pick a better ordering (the larger factor), the approach is to try both the original list-scheduler-based ordering and a DFS-based ordering, and pick whichever yields a smaller minimum memory, computed with the simulated heap while ignoring fragmentation. Note that this is the absolute minimum memory for a given ordering.
  To minimize fragmentation, the approach is to run a heap simulation on temporary buffers. We still try to re-use existing allocations when possible, but instead of creating new allocations for temp buffers, we collect all the leftovers and use a heap to pack them. The heap algorithm that gave the best results is "lazy best-fit": a variant of traditional best-fit that sometimes delays offset assignment until Free is called, in the hope of yielding larger free chunks.
  Here are measurements of the temp buffer sizes for GNMT encoder training (a stacked LSTM); lower is better. Various combinations of instruction ordering and heap simulation are shown, to illustrate the joint impact of the two factors:
    List-scheduler order, no heap simulation    33.33 GiB
    List-scheduler order, with heap simulation  25.09 GiB
    Minimized DFS order, no heap simulation     16.59 GiB
    Arbitrary DFS order, no heap simulation     15.05 GiB (old)
    Arbitrary DFS order, with heap simulation   12.57 GiB
    Minimized DFS order, with heap simulation   11.71 GiB (new)
  Note that the original list-scheduler order is much worse than DFS on stacked LSTMs, but (not shown here) is much better than DFS on convolutions like Inception. Also note that heap simulation packs things tighter for all instruction orders in this example, though to varying degrees.
  Change: 149049028
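  A plain best-fit heap simulator in the spirit of the description above (the actual pass uses the "lazy" variant that may defer offset assignment until Free; this simplified sketch assigns offsets eagerly, to show how simulation measures packing):

  ```python
  class HeapSimulator:
      def __init__(self):
          self.free = []        # list of (offset, size) free chunks
          self.allocs = {}      # buffer id -> (offset, size)
          self.heap_end = 0     # high-water mark of the simulated heap

      def alloc(self, buf, size):
          # Best fit: the smallest free chunk that is large enough.
          best = min((c for c in self.free if c[1] >= size),
                     key=lambda c: c[1], default=None)
          if best is None:
              offset = self.heap_end        # grow the heap
              self.heap_end += size
          else:
              self.free.remove(best)
              offset, chunk = best
              if chunk > size:              # return the leftover chunk
                  self.free.append((offset + size, chunk - size))
          self.allocs[buf] = (offset, size)
          return offset

      def free_buf(self, buf):
          offset, size = self.allocs.pop(buf)
          self.free.append((offset, size))
  ```

  Replaying a candidate instruction ordering through such a simulator yields the peak heap size (`heap_end`) for that ordering, which is exactly the number the CL uses to compare orderings.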