aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/deprecated
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-28 11:01:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-28 11:05:41 -0700
commitd3d60ff6acec178b1cf912938aa6180bbd1a676f (patch)
tree11f949fdcb6f3fd68827e9fa6d261b99cdb18c01 /tensorflow/contrib/deprecated
parent863329e469fe091dae2ce5f1c6851a809ce0d579 (diff)
Merge changes from github.
END_PUBLIC --- Commit 301b14c24 authored by Skye Wanderman-Milne<skyewm@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Basic while loop gradient functionality in C++ This change introduces the basic framework to create the gradient graph of a while loop using the C++ API. This supports building the gradient graph as long as the body function of the while loop contains no ops whose gradient function requires a stack. In other words, it doesn't support gradient functions that use the input values to the op (e.g. add will work, but multiply will not). It also doesn't support nested while loops, and doesn't detect all error cases. PiperOrigin-RevId: 170243281 --- Commit 545e3572f authored by Asim Shankar<ashankar@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Datasets: Reference the programmer's guide in API docs. PiperOrigin-RevId: 170241348 --- Commit 24890d550 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 170241322 --- Commit 02d2f3760 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Update ops-related pbtxt files. PiperOrigin-RevId: 170240603 --- Commit 759690f02 authored by Reed Wanderman-Milne<reedwm@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Add float16 support to tf.nn.fused_batch_norm on the GPU. Scale, offset, mean, and variance must still be float32 if the input is float16. PiperOrigin-RevId: 170239448 --- Commit 20370104c authored by Igor Saprykin<isaprykin@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Support export strategies in _TrainingExecutor. One could set export strategies to the EvalSpec. An exception is raised if the type isn't export_strategy.ExportStrategy. During continuous evaluation, export strategies are going to be triggered. They in turn call Estimator's export_savedmodel. PiperOrigin-RevId: 170237073 --- Commit 56402103e authored by Reed Wanderman-Milne<reedwm@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Fix BFC allocator's log messages on OOM error. Before, the "Chunks in use" message and other in-use messages would always be 0. PiperOrigin-RevId: 170233715 --- Commit bc80e46b1 authored by Peter Hawkins<phawkins@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [TF:XLA] Implement BroadcastArgs. PiperOrigin-RevId: 170228025 --- Commit bced6676e authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: BEGIN_PUBLIC Automated g4 rollback of changelist 170204652 PiperOrigin-RevId: 170367641
Diffstat (limited to 'tensorflow/contrib/deprecated')
-rw-r--r--tensorflow/contrib/deprecated/__init__.py58
1 files changed, 29 insertions, 29 deletions
diff --git a/tensorflow/contrib/deprecated/__init__.py b/tensorflow/contrib/deprecated/__init__.py
index 0bbca8d8ed..bfea8445a7 100644
--- a/tensorflow/contrib/deprecated/__init__.py
+++ b/tensorflow/contrib/deprecated/__init__.py
@@ -18,35 +18,32 @@ For TensorFlow 1.0, we have reorganized the TensorFlow summary ops into a
submodule, and made some semantic tweaks. The first thing to note is that we
moved the APIs around as follows:
+```python
tf.scalar_summary -> tf.summary.scalar
-
tf.histogram_summary -> tf.summary.histogram
-
tf.audio_summary -> tf.summary.audio
-
tf.image_summary -> tf.summary.image
-
tf.merge_summary -> tf.summary.merge
-
tf.merge_all_summaries -> tf.summary.merge_all
+```
-We think this is a cleaner API and will improve long-term discoverability and
-clarity of the TensorFlow API. However, we also took the opportunity to make an
+We think this API is cleaner and will improve long-term discoverability and
+clarity of the TensorFlow API. But we also took the opportunity to make an
important change to how summary "tags" work. The "tag" of a summary is the
string that is associated with the output data, i.e. the key for organizing the
generated protobufs.
-Previously, the tag was allowed to be any unique string, and had no relation
+Previously, the tag was allowed to be any unique string; it had no relation
to the summary op generating it, and no relation to the TensorFlow name system.
-This made it very difficult to write re-usable code that would add summary
-ops to the graph. If you had a function that would add summary ops, you would
-need to manually pass in a name scope to that function to create deduplicated
-tags, otherwise your program would fail with a runtime error due to tag
-collision.
-
-The new summary APIs under tf.summary throw away the "tag" as an independent
-concept; instead, the first argument is the node name. So summary tags now
-automatically inherit the surrounding TF name scope, and automatically
+This behavior made it very difficult to write reusable that would add
+summary ops to the graph. If you had a function to add summary ops, you would
+need to pass in a `tf.name_scope`, manually, to that function to create
+deduplicated tags. Otherwise your program would fail with a runtime error due
+to tag collision.
+
+The new summary APIs under `tf.summary` throw away the "tag" as an independent
+concept; instead, the first argument is the node name. So summary tags now
+automatically inherit the surrounding `tf.name_scope`, and automatically
are deduplicated if there is a conflict. Now however, the only allowed
characters are alphanumerics, underscores, and forward slashes. To make
migration easier, the new APIs automatically convert illegal characters to
@@ -54,6 +51,7 @@ underscores.
Just as an example, consider the following "before" and "after" code snippets:
+```python
# Before
def add_activation_summaries(v, scope):
tf.scalar_summary("%s/fraction_of_zero" % scope, tf.nn.fraction_of_zero(v))
@@ -63,27 +61,28 @@ def add_activation_summaries(v, scope):
def add_activation_summaries(v):
tf.summary.scalar("fraction_of_zero", tf.nn.fraction_of_zero(v))
tf.summary.histogram("activations", v)
+```
Now, so long as the add_activation_summaries function is called from within the
-right name scope, the behavior is the same.
+right `tf.name_scope`, the behavior is the same.
Because this change does modify the behavior and could break tests, we can't
automatically migrate usage to the new APIs. That is why we are making the old
-APIs temporarily available here at tf.contrib.deprecated.
+APIs temporarily available here at `tf.contrib.deprecated`.
In addition to the name change described above, there are two further changes
to the new summary ops:
-- the "max_images" argument for tf.image_summary was renamed to "max_outputs
- for tf.summary.image
-- tf.scalar_summary accepted arbitrary tensors of tags and values. However,
- tf.summary.scalar requires a single scalar name and scalar value. In most
- cases, you can create tf.summary.scalars in a loop to get the same behavior
+- the "max_images" argument for `tf.image_summary` was renamed to "max_outputs
+ for `tf.summary.image`
+- `tf.scalar_summary` accepted arbitrary tensors of tags and values. But
+ `tf.summary.scalar` requires a single scalar name and scalar value. In most
+ cases, you can create `tf.summary.scalar` in a loop to get the same behavior
-As before, TensorBoard groups charts by the top-level name scope. This may
-be inconvenient, since in the new summary ops the summary will inherit that
-name scope without user control. We plan to add more grouping mechanisms to
-TensorBoard, so it will be possible to specify the TensorBoard group for
+As before, TensorBoard groups charts by the top-level `tf.name_scope` which may
+be inconvenient, for in the new summary ops, the summary will inherit that
+`tf.name_scope` without user control. We plan to add more grouping mechanisms
+to TensorBoard, so it will be possible to specify the TensorBoard group for
each summary via the summary API.
"""
@@ -99,9 +98,10 @@ from tensorflow.python.ops.logging_ops import image_summary
from tensorflow.python.ops.logging_ops import merge_all_summaries
from tensorflow.python.ops.logging_ops import merge_summary
from tensorflow.python.ops.logging_ops import scalar_summary
-# pylint: enable=unused-import,line-too-long
from tensorflow.python.util.all_util import remove_undocumented
+# pylint: enable=unused-import,line-too-long
+
_allowed_symbols = ['audio_summary', 'histogram_summary',
'image_summary', 'merge_all_summaries',
'merge_summary', 'scalar_summary']