aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2015-12-09 17:40:18 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2015-12-09 17:40:18 -0800
commit27259353e50e6bcaeeedbc26dc3aaaa5695fe500 (patch)
treec8710d98861fe7ca767059faff0ad44858869d92 /tensorflow
parent5de908567355337fdebd997fb5c60993cbe9ba2e (diff)
TensorFlow: Upstream changes from git.
Change 109849574 Avoid some missing return warnings Change 109837783 Add invalid aggregation to error message. Change 109835474 Improves docstring of RegisterGradient decorator. Fixes a typo (input -> output) and uses lowercase name for neg in the provided example. Change 109834486 Update generated Op docs. Change 109830497 Fix per_image_whitening to handle edge case by preventing the sqrt() of a negative number which is possible due to numerical floating point issues. Unit test added. Fixes #411 Change 109824286 Change TensorBoard/TAG to 4 Change 109824197 Update tutorials and documentation usage of print to use print as function not statement. This way you can copy+paste code in a python3 context and it will still work. Change 109824020 Fix another case where TensorBoard discards values after a restart. We also need to not discard on graph_def, since user code or SummaryWriter may add graph_defs at step 0 after every restart. Change 109821924 Defines Templates for variable sharing. A Template is a function that generates a sub-graph with the same variables each time it is called. Two different templates defined with the same underlying function also return different variables. Change 109815801 Don't instatiate the eigen expressions for additions and subtractions of boolean since they won't be called. This reduces the size of the binary a bit. Change 109795730 Allow casts to and from int8 Change 109791914 Python 3 fix: filter has no len gradients.py calls len on the output of filter. A call to tuple is needed in between. Not sure why this wasn't caught when we ran the Python 3 tests. If I break it for Python 2 several tests break. Change 109757009 Fix minor grammatical errors in about.html The missing article needs no justification, I think. has -> have, because subjects are 'usability and functionality', not 'TensorFlow'. and also -> and, because 'also' is superfluous in this use. Change 109756627 TensorFlow: some doc updates to models/ files Change 109743899 TensorFlow: remove one more clang warning (class / struct inconsistency). Change 109741933 Document default for max_images in tf.image_summary It used to say max_images=None which hid the C++ defalut of 3. Now it says max_images=3. Fixes https://github.com/tensorflow/tensorflow/issues/441. It's unfortunate that an edit-distance-5 change produces such a large CL. Change 109741569 Update generated Op docs. Change 109739599 Renaming the Python variables in the layer weights of the fully connected MNIST model so that the variable and the TensorFlow names are different. This allows the documentation to be more explicit about the distinction between the weights and biases of different layers. Also, the documentation gets to describe the whether the TF name or the Python name is being used. Base CL: 109851372
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/framework/device_base.h2
-rw-r--r--tensorflow/core/kernels/cast_op.cc5
-rw-r--r--tensorflow/core/kernels/cast_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/scatter_op.cc43
-rw-r--r--tensorflow/core/platform/default/mutex.h4
-rw-r--r--tensorflow/g3doc/api_docs/cc/ClassTensorShape.md14
-rw-r--r--tensorflow/g3doc/api_docs/python/client.md12
-rw-r--r--tensorflow/g3doc/api_docs/python/constant_op.md70
-rw-r--r--tensorflow/g3doc/api_docs/python/framework.md10
-rw-r--r--tensorflow/g3doc/api_docs/python/image.md26
-rw-r--r--tensorflow/g3doc/api_docs/python/index.md2
-rw-r--r--tensorflow/g3doc/api_docs/python/io_ops.md4
-rw-r--r--tensorflow/g3doc/api_docs/python/nn.md38
-rw-r--r--tensorflow/g3doc/api_docs/python/state_ops.md102
-rw-r--r--tensorflow/g3doc/api_docs/python/train.md8
-rw-r--r--tensorflow/g3doc/get_started/basic_usage.md14
-rw-r--r--tensorflow/g3doc/get_started/index.md2
-rw-r--r--tensorflow/g3doc/get_started/os_setup.md48
-rw-r--r--tensorflow/g3doc/how_tos/reading_data/index.md4
-rw-r--r--tensorflow/g3doc/how_tos/variables/index.md4
-rw-r--r--tensorflow/g3doc/tutorials/mnist/beginners/index.md2
-rw-r--r--tensorflow/g3doc/tutorials/mnist/pros/index.md8
-rw-r--r--tensorflow/g3doc/tutorials/mnist/tf/index.md4
-rw-r--r--tensorflow/models/embedding/README.md4
-rw-r--r--tensorflow/models/rnn/ptb/ptb_word_lm.py19
-rw-r--r--tensorflow/opensource_only/__init__.py0
-rw-r--r--tensorflow/python/BUILD2
-rw-r--r--tensorflow/python/client/session.py12
-rw-r--r--tensorflow/python/framework/ops.py14
-rw-r--r--tensorflow/python/framework/random_seed.py60
-rw-r--r--tensorflow/python/kernel_tests/template_test.py213
-rw-r--r--tensorflow/python/ops/array_ops.py4
-rw-r--r--tensorflow/python/ops/constant_op.py10
-rw-r--r--tensorflow/python/ops/gradients.py5
-rw-r--r--tensorflow/python/ops/image_ops.py2
-rw-r--r--tensorflow/python/ops/image_ops_test.py8
-rw-r--r--tensorflow/python/ops/standard_ops.py1
-rw-r--r--tensorflow/python/ops/state_ops.py1
-rw-r--r--tensorflow/python/ops/summary_ops.py2
-rw-r--r--tensorflow/python/ops/template.py215
-rw-r--r--tensorflow/python/ops/variables.py4
-rw-r--r--tensorflow/python/summary/event_accumulator.py2
-rw-r--r--tensorflow/python/summary/event_accumulator_test.py9
-rw-r--r--tensorflow/python/summary/impl/event_file_loader.py1
-rw-r--r--tensorflow/python/training/summary_io.py4
-rw-r--r--tensorflow/python/training/training_util.py2
-rw-r--r--tensorflow/stream_executor/dnn.cc2
-rw-r--r--tensorflow/tensorboard/CHANGES6
-rw-r--r--tensorflow/tensorboard/TAG2
49 files changed, 817 insertions, 215 deletions
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index 66f181d4cf..99f09248b5 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/core/public/tensor.h"
namespace Eigen {
-class ThreadPoolDevice;
+struct ThreadPoolDevice;
} // end namespace Eigen
namespace perftools {
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index 8d5ed3c2fe..5bc65997de 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -58,6 +58,7 @@ struct CastFunctor<CPUDevice, O, I> {
#define CURRY_TYPES2(FN, arg0) \
FN(arg0, bool); \
FN(arg0, uint8); \
+ FN(arg0, int8); \
FN(arg0, int16); \
FN(arg0, int32); \
FN(arg0, int64); \
@@ -67,6 +68,7 @@ struct CastFunctor<CPUDevice, O, I> {
#define CURRY_TYPES3(FN, arg0, arg1) \
FN(arg0, arg1, bool); \
FN(arg0, arg1, uint8); \
+ FN(arg0, arg1, int8); \
FN(arg0, arg1, int16); \
FN(arg0, arg1, int32); \
FN(arg0, arg1, int64); \
@@ -130,6 +132,7 @@ class CpuCastOp : public CastOpBase {
}
CURRY_TYPES3(CAST_CASE, CPUDevice, bool);
CURRY_TYPES3(CAST_CASE, CPUDevice, uint8);
+ CURRY_TYPES3(CAST_CASE, CPUDevice, int8);
CURRY_TYPES3(CAST_CASE, CPUDevice, int16);
CURRY_TYPES3(CAST_CASE, CPUDevice, int32);
CURRY_TYPES3(CAST_CASE, CPUDevice, int64);
@@ -192,6 +195,7 @@ class GpuCastOp : public CastOpBase {
}
CURRY_TYPES3(CAST_CASE, GPUDevice, bool);
CURRY_TYPES3(CAST_CASE, GPUDevice, uint8);
+ CURRY_TYPES3(CAST_CASE, GPUDevice, int8);
CURRY_TYPES3(CAST_CASE, GPUDevice, int16);
CURRY_TYPES3(CAST_CASE, GPUDevice, int32);
CURRY_TYPES3(CAST_CASE, GPUDevice, int64);
@@ -217,6 +221,7 @@ REGISTER_KERNEL_BUILDER(Name("Cast").Device(DEVICE_CPU), CpuCastOp);
CURRY_TYPES2(REGISTER_CAST_GPU, bool);
CURRY_TYPES2(REGISTER_CAST_GPU, uint8);
+CURRY_TYPES2(REGISTER_CAST_GPU, int8);
CURRY_TYPES2(REGISTER_CAST_GPU, int16);
CURRY_TYPES2(REGISTER_CAST_GPU, int32);
CURRY_TYPES2(REGISTER_CAST_GPU, int64);
diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc
index 57f0873621..03958d1e37 100644
--- a/tensorflow/core/kernels/cast_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc
@@ -37,6 +37,7 @@ struct CastFunctor<GPUDevice, O, I> {
#define DEFINE_ALL_FROM(in_type) \
DEFINE(in_type, bool); \
DEFINE(in_type, uint8); \
+ DEFINE(in_type, int8); \
DEFINE(in_type, int16); \
DEFINE(in_type, int32); \
DEFINE(in_type, int64); \
@@ -45,6 +46,7 @@ struct CastFunctor<GPUDevice, O, I> {
DEFINE_ALL_FROM(bool);
DEFINE_ALL_FROM(uint8);
+DEFINE_ALL_FROM(int8);
DEFINE_ALL_FROM(int16);
DEFINE_ALL_FROM(int32);
DEFINE_ALL_FROM(int64);
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
index 59315876aa..84dd625a9f 100644
--- a/tensorflow/core/kernels/scatter_op.cc
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -24,6 +24,30 @@ namespace tensorflow {
enum class UpdateOp { ASSIGN, ADD, SUB };
+template <UpdateOp Op>
+struct Assign {};
+template <>
+struct Assign<UpdateOp::ASSIGN> {
+ template <typename Params, typename Update>
+ static void Run(Params p, Update u) {
+ p = u;
+ }
+};
+template <>
+struct Assign<UpdateOp::ADD> {
+ template <typename Params, typename Update>
+ static void Run(Params p, Update u) {
+ p += u;
+ }
+};
+template <>
+struct Assign<UpdateOp::SUB> {
+ template <typename Params, typename Update>
+ static void Run(Params p, Update u) {
+ p -= u;
+ }
+};
+
template <class T, typename Index, UpdateOp op>
class ScatterUpdateOp : public OpKernel {
public:
@@ -105,23 +129,8 @@ class ScatterUpdateOp : public OpKernel {
for (Index i = 0; i < N; i++) {
// Copy last Ndim-1 dimensions of Tupdates[i] to
// Tparams[Tindices[i]]
- switch (op) {
- case UpdateOp::ASSIGN: {
- Tparams_flat.template chip<0>(Tindices_vec(i)) =
- Tupdates_flat.template chip<0>(i);
- break;
- }
- case UpdateOp::ADD: {
- Tparams_flat.template chip<0>(Tindices_vec(i)) +=
- Tupdates_flat.template chip<0>(i);
- break;
- }
- case UpdateOp::SUB: {
- Tparams_flat.template chip<0>(Tindices_vec(i)) -=
- Tupdates_flat.template chip<0>(i);
- break;
- }
- }
+ Assign<op>::Run(Tparams_flat.template chip<0>(Tindices_vec(i)),
+ Tupdates_flat.template chip<0>(i));
}
}
}
diff --git a/tensorflow/core/platform/default/mutex.h b/tensorflow/core/platform/default/mutex.h
index be7558b7c1..5547f20e89 100644
--- a/tensorflow/core/platform/default/mutex.h
+++ b/tensorflow/core/platform/default/mutex.h
@@ -20,15 +20,13 @@ limitations under the License.
#include <condition_variable>
#include <mutex>
-#include "tensorflow/core/platform/default/thread_annotations.h"
-
namespace tensorflow {
enum LinkerInitialized { LINKER_INITIALIZED };
// A class that wraps around the std::mutex implementation, only adding an
// additional LinkerInitialized constructor interface.
-class LOCKABLE mutex : public std::mutex {
+class mutex : public std::mutex {
public:
mutex() {}
// The default implementation of std::mutex is safe to use after the linker
diff --git a/tensorflow/g3doc/api_docs/cc/ClassTensorShape.md b/tensorflow/g3doc/api_docs/cc/ClassTensorShape.md
index 2850da0cca..83a473f418 100644
--- a/tensorflow/g3doc/api_docs/cc/ClassTensorShape.md
+++ b/tensorflow/g3doc/api_docs/cc/ClassTensorShape.md
@@ -47,6 +47,8 @@ Manages the dimensions of a Tensor and their sizes.
* [`string tensorflow::TensorShape::ShortDebugString() const`](#string_tensorflow_TensorShape_ShortDebugString)
* [`static bool tensorflow::TensorShape::IsValid(const TensorShapeProto &proto)`](#static_bool_tensorflow_TensorShape_IsValid)
* Returns `true` iff `proto` is a valid tensor shape.
+* [`static Status tensorflow::TensorShape::IsValidShape(const TensorShapeProto &proto)`](#static_Status_tensorflow_TensorShape_IsValidShape)
+* [`static string tensorflow::TensorShape::ShortDebugString(const TensorShapeProto &proto)`](#static_string_tensorflow_TensorShape_ShortDebugString)
##Member Details
@@ -193,3 +195,15 @@ For error messages.
Returns `true` iff `proto` is a valid tensor shape.
+
+#### `static Status tensorflow::TensorShape::IsValidShape(const TensorShapeProto &proto)` {#static_Status_tensorflow_TensorShape_IsValidShape}
+
+
+
+Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error status otherwise.
+
+#### `static string tensorflow::TensorShape::ShortDebugString(const TensorShapeProto &proto)` {#static_string_tensorflow_TensorShape_ShortDebugString}
+
+
+
+Same as `TensorShape(proto).ShortDebugString()` but doesn&apos;t crash for invalid protos.
diff --git a/tensorflow/g3doc/api_docs/python/client.md b/tensorflow/g3doc/api_docs/python/client.md
index 3b8f1ceac7..4f243e58b3 100644
--- a/tensorflow/g3doc/api_docs/python/client.md
+++ b/tensorflow/g3doc/api_docs/python/client.md
@@ -30,7 +30,7 @@ c = a * b
sess = tf.Session()
# Evaluate the tensor `c`.
-print sess.run(c)
+print(sess.run(c))
```
A session may own resources, such as
@@ -196,7 +196,7 @@ sess = tf.Session()
with sess.as_default():
assert tf.get_default_session() is sess
- print c.eval()
+ print(c.eval())
```
To get the current default session, use
@@ -211,10 +211,10 @@ explicitly.
c = tf.constant(...)
sess = tf.Session()
with sess.as_default():
- print c.eval()
+ print(c.eval())
# ...
with sess.as_default():
- print c.eval()
+ print(c.eval())
sess.close()
```
@@ -258,7 +258,7 @@ a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b
# We can just use 'c.eval()' without passing 'sess'
-print c.eval()
+print(c.eval())
sess.close()
```
@@ -272,7 +272,7 @@ b = tf.constant(6.0)
c = a * b
with tf.Session():
# We can also use 'c.eval()' here.
- print c.eval()
+ print(c.eval())
```
- - -
diff --git a/tensorflow/g3doc/api_docs/python/constant_op.md b/tensorflow/g3doc/api_docs/python/constant_op.md
index 3aec8c1136..4018a1533b 100644
--- a/tensorflow/g3doc/api_docs/python/constant_op.md
+++ b/tensorflow/g3doc/api_docs/python/constant_op.md
@@ -319,15 +319,15 @@ shuff = tf.random_shuffle(c)
# Each time we run these ops, different results are generated
sess = tf.Session()
-print sess.run(norm)
-print sess.run(norm)
+print(sess.run(norm))
+print(sess.run(norm))
# Set an op-level seed to generate repeatable sequences across sessions.
c = tf.constant([[1, 2], [3, 4], [5, 6]])
sess = tf.Session()
norm = tf.random_normal(c, seed=1234)
-print sess.run(norm)
-print sess.run(norm)
+print(sess.run(norm))
+print(sess.run(norm))
```
Another common use of random values is the initialization of variables. Also see
@@ -341,7 +341,7 @@ init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
-print sess.run(var)
+print(sess.run(var))
```
- - -
@@ -507,19 +507,19 @@ graph-level nor op-level seeds:
a = tf.random_uniform([1])
b = tf.random_normal([1])
-print "Session 1"
+print("Session 1")
with tf.Session() as sess1:
- print sess1.run(a) # generates 'A1'
- print sess1.run(a) # generates 'A2'
- print sess1.run(b) # generates 'B1'
- print sess1.run(b) # generates 'B2'
+ print(sess1.run(a)) # generates 'A1'
+ print(sess1.run(a)) # generates 'A2'
+ print(sess1.run(b)) # generates 'B1'
+ print(sess1.run(b)) # generates 'B2'
-print "Session 2"
+print("Session 2")
with tf.Session() as sess2:
- print sess2.run(a) # generates 'A3'
- print sess2.run(a) # generates 'A4'
- print sess2.run(b) # generates 'B3'
- print sess2.run(b) # generates 'B4'
+ print(sess2.run(a)) # generates 'A3'
+ print(sess2.run(a)) # generates 'A4'
+ print(sess2.run(b)) # generates 'B3'
+ print(sess2.run(b)) # generates 'B4'
```
To generate the same repeatable sequence for an op across sessions, set the
@@ -531,19 +531,19 @@ b = tf.random_normal([1])
# Repeatedly running this block with the same graph will generate the same
# sequence of values for 'a', but different sequences of values for 'b'.
-print "Session 1"
+print("Session 1")
with tf.Session() as sess1:
- print sess1.run(a) # generates 'A1'
- print sess1.run(a) # generates 'A2'
- print sess1.run(b) # generates 'B1'
- print sess1.run(b) # generates 'B2'
+ print(sess1.run(a)) # generates 'A1'
+ print(sess1.run(a)) # generates 'A2'
+ print(sess1.run(b)) # generates 'B1'
+ print(sess1.run(b)) # generates 'B2'
-print "Session 2"
+print("Session 2")
with tf.Session() as sess2:
- print sess2.run(a) # generates 'A1'
- print sess2.run(a) # generates 'A2'
- print sess2.run(b) # generates 'B3'
- print sess2.run(b) # generates 'B4'
+ print(sess2.run(a)) # generates 'A1'
+ print(sess2.run(a)) # generates 'A2'
+ print(sess2.run(b)) # generates 'B3'
+ print(sess2.run(b)) # generates 'B4'
```
To make the random sequences generated by all ops be repeatable across
@@ -556,19 +556,19 @@ b = tf.random_normal([1])
# Repeatedly running this block with the same graph will generate different
# sequences of 'a' and 'b'.
-print "Session 1"
+print("Session 1")
with tf.Session() as sess1:
- print sess1.run(a) # generates 'A1'
- print sess1.run(a) # generates 'A2'
- print sess1.run(b) # generates 'B1'
- print sess1.run(b) # generates 'B2'
+ print(sess1.run(a)) # generates 'A1'
+ print(sess1.run(a)) # generates 'A2'
+ print(sess1.run(b)) # generates 'B1'
+ print(sess1.run(b)) # generates 'B2'
-print "Session 2"
+print("Session 2")
with tf.Session() as sess2:
- print sess2.run(a) # generates 'A1'
- print sess2.run(a) # generates 'A2'
- print sess2.run(b) # generates 'B1'
- print sess2.run(b) # generates 'B2'
+ print(sess2.run(a)) # generates 'A1'
+ print(sess2.run(a)) # generates 'A2'
+ print(sess2.run(b)) # generates 'B1'
+ print(sess2.run(b)) # generates 'B2'
```
##### Args:
diff --git a/tensorflow/g3doc/api_docs/python/framework.md b/tensorflow/g3doc/api_docs/python/framework.md
index a075fe367e..115fd3e956 100644
--- a/tensorflow/g3doc/api_docs/python/framework.md
+++ b/tensorflow/g3doc/api_docs/python/framework.md
@@ -1017,12 +1017,12 @@ example:
```python
c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
-print c.get_shape()
+print(c.get_shape())
==> TensorShape([Dimension(2), Dimension(3)])
d = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
-print d.get_shape()
+print(d.get_shape())
==> TensorShape([Dimension(4), Dimension(2)])
# Raises a ValueError, because `c` and `d` do not have compatible
@@ -1031,7 +1031,7 @@ e = tf.matmul(c, d)
f = tf.matmul(c, d, transpose_a=True, transpose_b=True)
-print f.get_shape()
+print(f.get_shape())
==> TensorShape([Dimension(3), Dimension(4)])
```
@@ -1063,12 +1063,12 @@ image = tf.image.decode_png(image_data, channels=3)
# The height and width dimensions of `image` are data dependent, and
# cannot be computed without executing the op.
-print image.get_shape()
+print(image.get_shape())
==> TensorShape([Dimension(None), Dimension(None), Dimension(3)])
# We know that each image in this dataset is 28 x 28 pixels.
image.set_shape([28, 28, 3])
-print image.get_shape()
+print(image.get_shape())
==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
```
diff --git a/tensorflow/g3doc/api_docs/python/image.md b/tensorflow/g3doc/api_docs/python/image.md
index cfa255e083..1cf61f4365 100644
--- a/tensorflow/g3doc/api_docs/python/image.md
+++ b/tensorflow/g3doc/api_docs/python/image.md
@@ -325,8 +325,6 @@ Input images can be of different types but output images are always float.
Resize `images` to `size` using nearest neighbor interpolation.
-Input images can be of different types but output images are always float.
-
##### Args:
@@ -1107,27 +1105,3 @@ Note that this implementation is limited:
* <b>`ValueError`</b>: if the shape of 'image' is incompatible with this function.
-
-## Other Functions and Classes
-- - -
-
-### `tf.image.resize_nearest_neighbor_grad(grads, size, name=None)` {#resize_nearest_neighbor_grad}
-
-Computes the gradient of nearest neighbor interpolation.
-
-##### Args:
-
-
-* <b>`grads`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int32`, `float32`, `float64`.
- 4-D with shape `[batch, height, width, channels]`.
-* <b>`size`</b>: A 1-D int32 Tensor of 2 elements: `orig_height, orig_width`. The
- original input size.
-* <b>`name`</b>: A name for the operation (optional).
-
-##### Returns:
-
- A `Tensor`. Has the same type as `grads`.
- 4-D with shape `[batch, orig_height, orig_width, channels]`. Gradients
- with respect to the input image.
-
-
diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md
index 4a91bad1a7..ed7ef413e9 100644
--- a/tensorflow/g3doc/api_docs/python/index.md
+++ b/tensorflow/g3doc/api_docs/python/index.md
@@ -57,6 +57,7 @@
* [`initialize_all_variables`](../../api_docs/python/state_ops.md#initialize_all_variables)
* [`initialize_variables`](../../api_docs/python/state_ops.md#initialize_variables)
* [`latest_checkpoint`](../../api_docs/python/state_ops.md#latest_checkpoint)
+ * [`make_template`](../../api_docs/python/state_ops.md#make_template)
* [`random_normal_initializer`](../../api_docs/python/state_ops.md#random_normal_initializer)
* [`random_uniform_initializer`](../../api_docs/python/state_ops.md#random_uniform_initializer)
* [`Saver`](../../api_docs/python/state_ops.md#Saver)
@@ -226,7 +227,6 @@
* [`resize_image_with_crop_or_pad`](../../api_docs/python/image.md#resize_image_with_crop_or_pad)
* [`resize_images`](../../api_docs/python/image.md#resize_images)
* [`resize_nearest_neighbor`](../../api_docs/python/image.md#resize_nearest_neighbor)
- * [`resize_nearest_neighbor_grad`](../../api_docs/python/image.md#resize_nearest_neighbor_grad)
* [`rgb_to_grayscale`](../../api_docs/python/image.md#rgb_to_grayscale)
* [`rgb_to_hsv`](../../api_docs/python/image.md#rgb_to_hsv)
* [`transpose_image`](../../api_docs/python/image.md#transpose_image)
diff --git a/tensorflow/g3doc/api_docs/python/io_ops.md b/tensorflow/g3doc/api_docs/python/io_ops.md
index 0b694a55c4..3ec4ab036b 100644
--- a/tensorflow/g3doc/api_docs/python/io_ops.md
+++ b/tensorflow/g3doc/api_docs/python/io_ops.md
@@ -30,10 +30,10 @@ x = tf.placeholder(tf.float32, shape=(1024, 1024))
y = tf.matmul(x, x)
with tf.Session() as sess:
- print sess.run(y) # ERROR: will fail because x was not fed.
+ print(sess.run(y)) # ERROR: will fail because x was not fed.
rand_array = np.random.rand(1024, 1024)
- print sess.run(y, feed_dict={x: rand_array}) # Will succeed.
+ print(sess.run(y, feed_dict={x: rand_array})) # Will succeed.
```
##### Args:
diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md
index 67d4afca07..0e1933ad72 100644
--- a/tensorflow/g3doc/api_docs/python/nn.md
+++ b/tensorflow/g3doc/api_docs/python/nn.md
@@ -737,7 +737,7 @@ tensors.
- - -
-### `tf.nn.embedding_lookup(params, ids, name=None)` {#embedding_lookup}
+### `tf.nn.embedding_lookup(params, ids, partition_strategy='mod', name=None)` {#embedding_lookup}
Looks up `ids` in a list of embedding tensors.
@@ -747,18 +747,34 @@ tensors in `params`. It is a generalization of
interpreted as a partition of a larger embedding tensor.
If `len(params) > 1`, each element `id` of `ids` is partitioned between
-the elements of `params` by computing `p = id % len(params)`, and is
-then used to look up the slice `params[p][id // len(params), ...]`.
+the elements of `params` according to the `partition_strategy`.
+In all strategies, if the id space does not evenly divide the number of
+partitions, each of the first `(max_id + 1) % len(params)` partitions will
+be assigned one more id.
-The results of the lookup are then concatenated into a dense
+If `partition_strategy` is `"mod"`, we assign each id to partition
+`p = id % len(params)`. For instance,
+13 ids are split across 5 partitions as:
+`[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]`
+
+If `partition_strategy` is `"div"`, we assign ids to partitions in a
+contiguous manner. In this case, 13 ids are split across 5 partitions as:
+`[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`
+
+The results of the lookup are concatenated into a dense
tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
##### Args:
-* <b>`params`</b>: A list of tensors with the same shape and type.
+* <b>`params`</b>: A list of tensors with the same type and which can be concatenated
+ along dimension 0. Each `Tensor` must be appropriately sized for the given
+ `partition_strategy`.
* <b>`ids`</b>: A `Tensor` with type `int32` or `int64` containing the ids to be looked
up in `params`.
+* <b>`partition_strategy`</b>: A string specifying the partitioning strategy, relevant
+ if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
+ is `"mod"`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
@@ -864,7 +880,7 @@ TensorFlow provides the following sampled loss functions for faster training.
- - -
-### `tf.nn.nce_loss(weights, biases, inputs, labels, num_sampled, num_classes, num_true=1, sampled_values=None, remove_accidental_hits=False, name='nce_loss')` {#nce_loss}
+### `tf.nn.nce_loss(weights, biases, inputs, labels, num_sampled, num_classes, num_true=1, sampled_values=None, remove_accidental_hits=False, partition_strategy='mod', name='nce_loss')` {#nce_loss}
Computes and returns the noise-contrastive estimation training loss.
@@ -889,7 +905,7 @@ with an otherwise unused class.
* <b>`weights`</b>: A `Tensor` of shape `[num_classes, dim]`, or a list of `Tensor`
objects whose concatenation along dimension 0 has shape
- [num_classes, dim]. The (possibly-sharded) class embeddings.
+ [num_classes, dim]. The (possibly-partitioned) class embeddings.
* <b>`biases`</b>: A `Tensor` of shape `[num_classes]`. The class biases.
* <b>`inputs`</b>: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
@@ -908,6 +924,9 @@ with an otherwise unused class.
our [Candidate Sampling Algorithms Reference]
(../../extras/candidate_sampling.pdf).
Default is False.
+* <b>`partition_strategy`</b>: A string specifying the partitioning strategy, relevant
+ if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
+ Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
@@ -917,7 +936,7 @@ with an otherwise unused class.
- - -
-### `tf.nn.sampled_softmax_loss(weights, biases, inputs, labels, num_sampled, num_classes, num_true=1, sampled_values=None, remove_accidental_hits=True, name='sampled_softmax_loss')` {#sampled_softmax_loss}
+### `tf.nn.sampled_softmax_loss(weights, biases, inputs, labels, num_sampled, num_classes, num_true=1, sampled_values=None, remove_accidental_hits=True, partition_strategy='mod', name='sampled_softmax_loss')` {#sampled_softmax_loss}
Computes and returns the sampled softmax training loss.
@@ -956,6 +975,9 @@ Also see Section 3 of http://arxiv.org/abs/1412.2007 for the math.
* <b>`remove_accidental_hits`</b>: A `bool`. whether to remove "accidental hits"
where a sampled class equals one of the target classes. Default is
True.
+* <b>`partition_strategy`</b>: A string specifying the partitioning strategy, relevant
+ if `len(weights) > 1`. Currently `"div"` and `"mod"` are supported.
+ Default is `"mod"`. See `tf.nn.embedding_lookup` for more details.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
diff --git a/tensorflow/g3doc/api_docs/python/state_ops.md b/tensorflow/g3doc/api_docs/python/state_ops.md
index 27e42107fb..a599afe615 100644
--- a/tensorflow/g3doc/api_docs/python/state_ops.md
+++ b/tensorflow/g3doc/api_docs/python/state_ops.md
@@ -306,10 +306,10 @@ init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
# Usage passing the session explicitly.
- print v.eval(sess)
+ print(v.eval(sess))
# Usage with the default session. The 'with' block
# above makes 'sess' the default session.
- print v.eval()
+ print(v.eval())
```
##### Args:
@@ -915,6 +915,104 @@ Returns the current variable scope.
- - -
+### `tf.make_template(name_, func_, **kwargs)` {#make_template}
+
+Given an arbitrary function, wrap it so that it does variable sharing.
+
+This wraps `func_` in a Template and partially evaluates it. Templates are
+functions that create variables the first time they are called and reuse them
+thereafter. In order for `func_` to be compatible with a `Template` it must
+have the following properties:
+
+* The function should create all trainable variables and any variables that
+ should be reused by calling `tf.get_variable`. If a trainable variable is
+ created using `tf.Variable`, then a ValueError will be thrown. Variables
+ that are intended to be locals can be created by specifying
+ `tf.Variable(..., trainable=false)`.
+* The function may use variable scopes and other templates internally to
+ create and reuse variables, but it shouldn't use `tf.get_variables` to
+ capture variables that are defined outside of the scope of the function.
+* Internal scopes and variable names should not depend on any arguments that
+ are not supplied to `make_template`. In general you will get a ValueError
+ telling you that you are trying to reuse a variable that doesn't exist
+ if you make a mistake.
+
+In the following example, both `z` and `w` will be scaled by the same `y`. It
+is important to note that if we didn't assign `scalar_name` and used a
+different name for z and w that a `ValueError` would be thrown because it
+couldn't reuse the variable.
+
+```python
+def my_op(x, scalar_name):
+ var1 = tf.get_variable(scalar_name,
+ shape=[],
+ initializer=tf.constant_initializer(1))
+ return x * var1
+
+scale_by_y = tf.make_template('scale_by_y', my_op, scalar_name='y')
+
+z = scale_by_y(input1)
+w = scale_by_y(input2)
+```
+
+As a safe-guard, the returned function will raise a `ValueError` after the
+first call if trainable variables are created by calling `tf.Variable`.
+
+If all of these are true, then 2 properties are enforced by the template:
+
+1. Calling the same template multiple times will share all non-local
+ variables.
+2. Two different templates are guaranteed to be unique, unless you reenter the
+ same variable scope as the initial definition of a template and redefine
+ it. An examples of this exception:
+
+```python
+def my_op(x, scalar_name):
+ var1 = tf.get_variable(scalar_name,
+ shape=[],
+ initializer=tf.constant_initializer(1))
+ return x * var1
+
+with tf.variable_scope('scope') as vs:
+ scale_by_y = tf.make_template('scale_by_y', my_op, scalar_name='y')
+ z = scale_by_y(input1)
+ w = scale_by_y(input2)
+
+# Creates a template that reuses the variables above.
+with tf.variable_scope(vs, reuse=True):
+ scale_by_y2 = tf.make_template('scale_by_y', my_op, scalar_name='y')
+ z2 = scale_by_y2(input1)
+ w2 = scale_by_y2(input2)
+```
+
+Note: The full variable scope is captured at the time of the first call.
+
+Note: `name_` and `func_` have a following underscore to reduce the likelihood
+of collisions with kwargs.
+
+##### Args:
+
+
+* <b>`name_`</b>: A name for the scope created by this template. If necessary, the name
+ will be made unique by appending `_N` to the name.
+* <b>`func_`</b>: The function to wrap.
+* <b>`**kwargs`</b>: Keyword arguments to apply to `func_`.
+
+##### Returns:
+
+ A function that will enter a `variable_scope` before calling `func_`. The
+ first time it is called, it will create a non-reusing scope so that the
+ variables will be unique. On each subsequent call, it will reuse those
+ variables.
+
+##### Raises:
+
+
+* <b>`ValueError`</b>: if the name is None.
+
+
+- - -
+
### `tf.variable_op_scope(values, name, default_name, initializer=None)` {#variable_op_scope}
Returns a context manager for defining an op that creates variables.
diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md
index fd6c353bac..29090f3807 100644
--- a/tensorflow/g3doc/api_docs/python/train.md
+++ b/tensorflow/g3doc/api_docs/python/train.md
@@ -1400,7 +1400,7 @@ summary has a summary value for each tag-value pair in `tags` and `values`.
- - -
-### `tf.image_summary(tag, tensor, max_images=None, collections=None, name=None)` {#image_summary}
+### `tf.image_summary(tag, tensor, max_images=3, collections=None, name=None)` {#image_summary}
Outputs a `Summary` protocol buffer with images.
@@ -1701,7 +1701,7 @@ Example: Print the contents of an events file.
```python
for e in tf.summary_iterator(path to events file):
- print e
+ print(e)
```
Example: Print selected summary values.
@@ -1714,7 +1714,7 @@ Example: Print selected summary values.
for e in tf.summary_iterator(path to events file):
for v in e.summary.value:
if v.tag == 'loss':
- print v.simple_value
+ print(v.simple_value)
```
See the protocol buffer definitions of
@@ -1749,7 +1749,7 @@ global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
sess = tf.Session()
# Initializes the variable.
sess.run(global_step_tensor.initializer)
-print 'global_step:', tf.train.global_step(sess, global_step_tensor)
+print('global_step:', tf.train.global_step(sess, global_step_tensor))
global_step: 10
```
diff --git a/tensorflow/g3doc/get_started/basic_usage.md b/tensorflow/g3doc/get_started/basic_usage.md
index cca15c1de4..66541966cd 100644
--- a/tensorflow/g3doc/get_started/basic_usage.md
+++ b/tensorflow/g3doc/get_started/basic_usage.md
@@ -101,7 +101,7 @@ sess = tf.Session()
#
# The output of the op is returned in 'result' as a numpy `ndarray` object.
result = sess.run(product)
-print result
+print(result)
# ==> [[ 12.]]
# Close the Session when we're done.
@@ -115,7 +115,7 @@ with a "with" block. The `Session` closes automatically at the end of the
```python
with tf.Session() as sess:
result = sess.run([product])
- print result
+ print(result)
```
The TensorFlow implementation translates the graph definition into executable
@@ -173,7 +173,7 @@ x.initializer.run()
# Add an op to subtract 'a' from 'x'. Run it and print the result
sub = tf.sub(x, a)
-print sub.eval()
+print(sub.eval())
# ==> [-2. -1.]
```
@@ -212,11 +212,11 @@ with tf.Session() as sess:
# Run the 'init' op
sess.run(init_op)
# Print the initial value of 'state'
- print sess.run(state)
+ print(sess.run(state))
# Run the op that updates 'state' and print 'state'.
for _ in range(3):
sess.run(update)
- print sess.run(state)
+ print(sess.run(state))
# output:
@@ -251,7 +251,7 @@ mul = tf.mul(input1, intermed)
with tf.Session() as sess:
result = sess.run([mul, intermed])
- print result
+ print(result)
# output:
# [array([ 21.], dtype=float32), array([ 7.], dtype=float32)]
@@ -279,7 +279,7 @@ input2 = tf.placeholder(tf.float32)
output = tf.mul(input1, input2)
with tf.Session() as sess:
- print sess.run([output], feed_dict={input1:[7.], input2:[2.]})
+ print(sess.run([output], feed_dict={input1:[7.], input2:[2.]}))
# output:
# [array([ 14.], dtype=float32)]
diff --git a/tensorflow/g3doc/get_started/index.md b/tensorflow/g3doc/get_started/index.md
index c00954000e..ffd4d1447c 100644
--- a/tensorflow/g3doc/get_started/index.md
+++ b/tensorflow/g3doc/get_started/index.md
@@ -40,7 +40,7 @@ sess.run(init)
for step in xrange(201):
sess.run(train)
if step % 20 == 0:
- print step, sess.run(W), sess.run(b)
+ print(step, sess.run(W), sess.run(b))
# Learns best fit is W: [0.1], b: [0.3]
```
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index 65016c54b5..effb66a708 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -6,7 +6,7 @@ github source.
## Requirements
The TensorFlow Python API currently supports Python 2.7 and Python 3.3+ from
-source. We are preparing Python 3 pip packages to go with the 0.6.0 release.
+source.
The GPU version (Linux only) currently requires the Cuda Toolkit 7.0 and CUDNN
6.5 V2. Please see [Cuda installation](#install_cuda).
@@ -39,7 +39,7 @@ Python.
The packages that will be installed or upgraded during the pip install are listed in the
[REQUIRED_PACKAGES section of setup.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/pip_package/setup.py)
-Install pip if it is not already installed:
+Install pip (or pip3 for python3) if it is not already installed:
```bash
# Ubuntu/Linux 64-bit
@@ -53,16 +53,31 @@ Install TensorFlow:
```bash
# Ubuntu/Linux 64-bit, CPU only:
-$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl
+$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.6.0-cp27-none-linux_x86_64.whl
# Ubuntu/Linux 64-bit, GPU enabled:
-$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.5.0-cp27-none-linux_x86_64.whl
+$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.6.0-cp27-none-linux_x86_64.whl
# Mac OS X, CPU only:
$ sudo easy_install --upgrade six
-$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.5.0-py2-none-any.whl
+$ sudo pip install --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.6.0-py2-none-any.whl
```
+For python3:
+
+```bash
+# Ubuntu/Linux 64-bit, CPU only:
+$ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.6.0-cp34-cp34m-linux_x86_64.whl
+
+# Ubuntu/Linux 64-bit, GPU enabled:
+$ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.6.0-cp34-cp34m-linux_x86_64.whl
+
+# Mac OS X, CPU only:
+$ sudo easy_install --upgrade six
+$ sudo pip3 install --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.6.0-py3-none-any.whl
+```
+
+
You can now [test your installation](#test_install).
## Virtualenv installation {#virtualenv_install}
@@ -115,6 +130,23 @@ $ source ~/tensorflow/bin/activate.csh # If using csh
(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.5.0-py2-none-any.whl
```
+and again for python3:
+
+```bash
+$ source ~/tensorflow/bin/activate # If using bash
+$ source ~/tensorflow/bin/activate.csh # If using csh
+(tensorflow)$ # Your prompt should change
+
+# Ubuntu/Linux 64-bit, CPU only:
+(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.6.0-cp34-cp34m-linux_x86_64.whl
+
+# Ubuntu/Linux 64-bit, GPU enabled:
+(tensorflow)$ pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.6.0-cp34-cp34m-linux_x86_64.whl
+
+# Mac OS X, CPU only:
+(tensorflow)$ pip3 install --upgrade https://storage.googleapis.com/tensorflow/mac/tensorflow-0.6.0-py3-none-any.whl
+```
+
With the Virtualenv environment activated, you can now
[test your installation](#test_install).
@@ -206,11 +238,11 @@ $ python
>>> import tensorflow as tf
>>> hello = tf.constant('Hello, TensorFlow!')
>>> sess = tf.Session()
->>> print sess.run(hello)
+>>> print(sess.run(hello))
Hello, TensorFlow!
>>> a = tf.constant(10)
>>> b = tf.constant(32)
->>> print sess.run(a + b)
+>>> print(sess.run(a + b))
42
>>>
```
@@ -228,7 +260,7 @@ The exact location of the Python library depends on your system, but is usually
You can find out the directory with the following command:
```bash
-$ python -c 'import site; print "\n".join(site.getsitepackages())'
+$ python -c 'import site; print("\n".join(site.getsitepackages()))'
```
The simple demo model for classifying handwritten digits from the MNIST dataset
diff --git a/tensorflow/g3doc/how_tos/reading_data/index.md b/tensorflow/g3doc/how_tos/reading_data/index.md
index 7caff72c2e..ba1c4454b2 100644
--- a/tensorflow/g3doc/how_tos/reading_data/index.md
+++ b/tensorflow/g3doc/how_tos/reading_data/index.md
@@ -23,7 +23,7 @@ that initiates computation.
with tf.Session():
input = tf.placeholder(tf.float32)
classifier = ...
- print classifier.eval(feed_dict={input: my_python_preprocessing_fn()})
+ print(classifier.eval(feed_dict={input: my_python_preprocessing_fn()}))
```
While you can replace any Tensor with feed data, including variables and
@@ -287,7 +287,7 @@ try:
sess.run(train_op)
except tf.errors.OutOfRangeError:
- print 'Done training -- epoch limit reached'
+ print('Done training -- epoch limit reached')
finally:
# When done, ask the threads to stop.
coord.request_stop()
diff --git a/tensorflow/g3doc/how_tos/variables/index.md b/tensorflow/g3doc/how_tos/variables/index.md
index 6c3bce6c0b..9a80b9eb88 100644
--- a/tensorflow/g3doc/how_tos/variables/index.md
+++ b/tensorflow/g3doc/how_tos/variables/index.md
@@ -146,7 +146,7 @@ with tf.Session() as sess:
..
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
- print "Model saved in file: ", save_path
+ print("Model saved in file: ", save_path)
```
### Restoring Variables
@@ -167,7 +167,7 @@ saver = tf.train.Saver()
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")
- print "Model restored."
+ print("Model restored.")
# Do some work with the model
...
```
diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/index.md b/tensorflow/g3doc/tutorials/mnist/beginners/index.md
index 2d4d6c566c..1367e931b8 100644
--- a/tensorflow/g3doc/tutorials/mnist/beginners/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/beginners/index.md
@@ -409,7 +409,7 @@ accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
Finally, we ask for our accuracy on our test data.
```python
-print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
+print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
```
This should be about 91%.
diff --git a/tensorflow/g3doc/tutorials/mnist/pros/index.md b/tensorflow/g3doc/tutorials/mnist/pros/index.md
index 866d4c8367..0ffcafc879 100644
--- a/tensorflow/g3doc/tutorials/mnist/pros/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/pros/index.md
@@ -227,7 +227,7 @@ Finally, we can evaluate our accuracy on the test data. This should be about
91% correct.
```python
-print accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels})
+print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
```
## Build a Multilayer Convolutional Network
@@ -380,11 +380,11 @@ for i in range(20000):
if i%100 == 0:
train_accuracy = accuracy.eval(feed_dict={
x:batch[0], y_: batch[1], keep_prob: 1.0})
- print "step %d, training accuracy %g"%(i, train_accuracy)
+ print("step %d, training accuracy %g"%(i, train_accuracy))
train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
-print "test accuracy %g"%accuracy.eval(feed_dict={
- x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})
+print("test accuracy %g"%accuracy.eval(feed_dict={
+ x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
```
The final test set accuracy after running this code should be approximately 99.2%.
diff --git a/tensorflow/g3doc/tutorials/mnist/tf/index.md b/tensorflow/g3doc/tutorials/mnist/tf/index.md
index 94418481a6..cb0abf6829 100644
--- a/tensorflow/g3doc/tutorials/mnist/tf/index.md
+++ b/tensorflow/g3doc/tutorials/mnist/tf/index.md
@@ -145,8 +145,8 @@ of units in the layer to which they connect.
The graph's three primary ops -- two [`tf.nn.relu`](../../../api_docs/python/nn.md#relu)
ops wrapping [`tf.matmul`](../../../api_docs/python/math_ops.md#matmul)
for the hidden layers and one extra `tf.matmul` for the logits -- are then
-created, each in turn, with their `tf.Variable` instances connected to the
-input placeholder or the output tensor of the layer beneath each.
+created, each in turn, with separate `tf.Variable` instances connected to each
+of the input placeholders or the output tensors of the previous layer.
```python
hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
diff --git a/tensorflow/models/embedding/README.md b/tensorflow/models/embedding/README.md
index b89c2cb944..bec585661a 100644
--- a/tensorflow/models/embedding/README.md
+++ b/tensorflow/models/embedding/README.md
@@ -12,8 +12,8 @@ tutorials. Brief instructions are below.
To download the example text and evaluation data:
```shell
-wget http://mattmahoney.net/dc/text8.zip -O text8.gz
-gzip -d text8.gz -f
+wget http://mattmahoney.net/dc/text8.zip -O text8.zip
+unzip text8.zip
wget https://word2vec.googlecode.com/svn/trunk/questions-words.txt
```
diff --git a/tensorflow/models/rnn/ptb/ptb_word_lm.py b/tensorflow/models/rnn/ptb/ptb_word_lm.py
index 3380a4fc92..41b67d8b24 100644
--- a/tensorflow/models/rnn/ptb/ptb_word_lm.py
+++ b/tensorflow/models/rnn/ptb/ptb_word_lm.py
@@ -19,11 +19,6 @@ Trains the model described in:
(Zaremba, et. al.) Recurrent Neural Network Regularization
http://arxiv.org/abs/1409.2329
-The data required for this example is in the data/ dir of the
-PTB dataset from Tomas Mikolov's webpage:
-
-http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
-
There are 3 supported model configurations:
===========================================
| config | epochs | train | valid | test
@@ -46,13 +41,15 @@ The hyperparameters used in the model:
- lr_decay - the decay of the learning rate for each epoch after "max_epoch"
- batch_size - the batch size
-To compile on CPU:
- bazel build -c opt tensorflow/models/rnn/ptb:ptb_word_lm
-To compile on GPU:
- bazel build -c opt tensorflow --config=cuda \
- tensorflow/models/rnn/ptb:ptb_word_lm
+The data required for this example is in the data/ dir of the
+PTB dataset from Tomas Mikolov's webpage:
+
+$ wget http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
+$ tar xvf simple-examples.tgz
+
To run:
- ./bazel-bin/.../ptb_word_lm --data_path=/tmp/simple-examples/data/
+
+$ python ptb_word_lm.py --data_path=simple-examples/data/
"""
from __future__ import absolute_import
diff --git a/tensorflow/opensource_only/__init__.py b/tensorflow/opensource_only/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
--- a/tensorflow/opensource_only/__init__.py
+++ /dev/null
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index f7e111d117..764f7a794d 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -638,6 +638,7 @@ py_library(
"ops/state_ops.py",
"ops/string_ops.py",
"ops/summary_ops.py",
+ "ops/template.py",
"ops/variable_scope.py",
"ops/variables.py",
"user_ops/user_ops.py",
@@ -895,6 +896,7 @@ cpu_only_kernel_test_list = glob([
"kernel_tests/sparse_to_dense_op_test.py",
"kernel_tests/sparsemask_op_test.py",
"kernel_tests/summary_ops_test.py",
+ "kernel_tests/template_test.py",
"kernel_tests/topk_op_test.py",
"kernel_tests/unique_op_test.py",
"kernel_tests/variable_scope_test.py",
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 8273d4f49d..4cc1b9809b 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -163,7 +163,7 @@ class BaseSession(SessionInterface):
with sess.as_default():
assert tf.get_default_session() is sess
- print c.eval()
+ print(c.eval())
```
To get the current default session, use
@@ -178,10 +178,10 @@ class BaseSession(SessionInterface):
c = tf.constant(...)
sess = tf.Session()
with sess.as_default():
- print c.eval()
+ print(c.eval())
# ...
with sess.as_default():
- print c.eval()
+ print(c.eval())
sess.close()
```
@@ -463,7 +463,7 @@ class Session(BaseSession):
sess = tf.Session()
# Evaluate the tensor `c`.
- print sess.run(c)
+ print(sess.run(c))
```
A session may own resources, such as
@@ -568,7 +568,7 @@ class InteractiveSession(BaseSession):
b = tf.constant(6.0)
c = a * b
# We can just use 'c.eval()' without passing 'sess'
- print c.eval()
+ print(c.eval())
sess.close()
```
@@ -582,7 +582,7 @@ class InteractiveSession(BaseSession):
c = a * b
with tf.Session():
# We can also use 'c.eval()' here.
- print c.eval()
+ print(c.eval())
```
@@__init__
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index f810e694b6..aa0ea4d1d3 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -269,12 +269,12 @@ class Tensor(object):
```python
c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
- print c.get_shape()
+ print(c.get_shape())
==> TensorShape([Dimension(2), Dimension(3)])
d = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
- print d.get_shape()
+ print(d.get_shape())
==> TensorShape([Dimension(4), Dimension(2)])
# Raises a ValueError, because `c` and `d` do not have compatible
@@ -283,7 +283,7 @@ class Tensor(object):
f = tf.matmul(c, d, transpose_a=True, transpose_b=True)
- print f.get_shape()
+ print(f.get_shape())
==> TensorShape([Dimension(3), Dimension(4)])
```
@@ -312,12 +312,12 @@ class Tensor(object):
# The height and width dimensions of `image` are data dependent, and
# cannot be computed without executing the op.
- print image.get_shape()
+ print(image.get_shape())
==> TensorShape([Dimension(None), Dimension(None), Dimension(3)])
# We know that each image in this dataset is 28 x 28 pixels.
image.set_shape([28, 28, 3])
- print image.get_shape()
+ print(image.get_shape())
==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
```
@@ -1332,7 +1332,7 @@ class RegisterGradient(object):
"""A decorator for registering the gradient function for an op type.
This decorator is only used when defining a new op type. For an op
- with `m` inputs and `n` inputs, the gradient function is a function
+ with `m` inputs and `n` outputs, the gradient function is a function
that takes the original `Operation` and `n` `Tensor` objects
(representing the gradients with respect to each output of the op),
and returns `m` `Tensor` objects (representing the partial gradients
@@ -1345,7 +1345,7 @@ class RegisterGradient(object):
```python
@tf.RegisterGradient("Sub")
def _sub_grad(unused_op, grad):
- return grad, tf.Neg(grad)
+ return grad, tf.neg(grad)
```
The decorator argument `op_type` is the string type of an
diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py
index dd0ef53c93..b70f626a9e 100644
--- a/tensorflow/python/framework/random_seed.py
+++ b/tensorflow/python/framework/random_seed.py
@@ -85,19 +85,19 @@ def set_random_seed(seed):
a = tf.random_uniform([1])
b = tf.random_normal([1])
- print "Session 1"
+ print("Session 1")
with tf.Session() as sess1:
- print sess1.run(a) # generates 'A1'
- print sess1.run(a) # generates 'A2'
- print sess1.run(b) # generates 'B1'
- print sess1.run(b) # generates 'B2'
+ print(sess1.run(a)) # generates 'A1'
+ print(sess1.run(a)) # generates 'A2'
+ print(sess1.run(b)) # generates 'B1'
+ print(sess1.run(b)) # generates 'B2'
- print "Session 2"
+ print("Session 2")
with tf.Session() as sess2:
- print sess2.run(a) # generates 'A3'
- print sess2.run(a) # generates 'A4'
- print sess2.run(b) # generates 'B3'
- print sess2.run(b) # generates 'B4'
+ print(sess2.run(a)) # generates 'A3'
+ print(sess2.run(a)) # generates 'A4'
+ print(sess2.run(b)) # generates 'B3'
+ print(sess2.run(b)) # generates 'B4'
```
To generate the same repeatable sequence for an op across sessions, set the
@@ -109,19 +109,19 @@ def set_random_seed(seed):
# Repeatedly running this block with the same graph will generate the same
# sequence of values for 'a', but different sequences of values for 'b'.
- print "Session 1"
+ print("Session 1")
with tf.Session() as sess1:
- print sess1.run(a) # generates 'A1'
- print sess1.run(a) # generates 'A2'
- print sess1.run(b) # generates 'B1'
- print sess1.run(b) # generates 'B2'
+ print(sess1.run(a)) # generates 'A1'
+ print(sess1.run(a)) # generates 'A2'
+ print(sess1.run(b)) # generates 'B1'
+ print(sess1.run(b)) # generates 'B2'
- print "Session 2"
+ print("Session 2")
with tf.Session() as sess2:
- print sess2.run(a) # generates 'A1'
- print sess2.run(a) # generates 'A2'
- print sess2.run(b) # generates 'B3'
- print sess2.run(b) # generates 'B4'
+ print(sess2.run(a)) # generates 'A1'
+ print(sess2.run(a)) # generates 'A2'
+ print(sess2.run(b)) # generates 'B3'
+ print(sess2.run(b)) # generates 'B4'
```
To make the random sequences generated by all ops be repeatable across
@@ -134,19 +134,19 @@ def set_random_seed(seed):
# Repeatedly running this block with the same graph will generate different
# sequences of 'a' and 'b'.
- print "Session 1"
+ print("Session 1")
with tf.Session() as sess1:
- print sess1.run(a) # generates 'A1'
- print sess1.run(a) # generates 'A2'
- print sess1.run(b) # generates 'B1'
- print sess1.run(b) # generates 'B2'
+ print(sess1.run(a)) # generates 'A1'
+ print(sess1.run(a)) # generates 'A2'
+ print(sess1.run(b)) # generates 'B1'
+ print(sess1.run(b)) # generates 'B2'
- print "Session 2"
+ print("Session 2")
with tf.Session() as sess2:
- print sess2.run(a) # generates 'A1'
- print sess2.run(a) # generates 'A2'
- print sess2.run(b) # generates 'B1'
- print sess2.run(b) # generates 'B2'
+ print(sess2.run(a)) # generates 'A1'
+ print(sess2.run(a)) # generates 'A2'
+ print(sess2.run(b)) # generates 'B1'
+ print(sess2.run(b)) # generates 'B2'
```
Args:
diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py
new file mode 100644
index 0000000000..794b012749
--- /dev/null
+++ b/tensorflow/python/kernel_tests/template_test.py
@@ -0,0 +1,213 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for make_template."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import traceback
+
+import tensorflow.python.platform
+
+import tensorflow as tf
+
+from tensorflow.python.ops import template
+
+
+def var_scoped_function():
+ return tf.get_variable("dummy",
+ shape=[1],
+ initializer=tf.zeros_initializer)
+
+
+def internally_var_scoped_function(scope_name):
+ with tf.variable_scope(scope_name):
+ return tf.get_variable("dummy",
+ shape=[1],
+ initializer=tf.zeros_initializer)
+
+
+def function_with_create(trainable):
+ """Creates a variable as a side effect using tf.Variable."""
+ tf.Variable(0, trainable=trainable)
+ return tf.get_variable("dummy",
+ shape=[1],
+ initializer=tf.zeros_initializer)
+
+
+class TemplateTest(tf.test.TestCase):
+
+ def test_end_to_end(self):
+ """This test shows a very simple line model with test_loss.
+
+ The template is used to share parameters between a training and test model.
+ """
+ # y = 2x + 1
+ training_input, training_output = ([1., 2., 3., 4.], [2.8, 5.1, 7.2, 8.7])
+ test_input, test_output = ([5., 6., 7., 8.], [11, 13, 15, 17])
+
+ tf.set_random_seed(1234)
+
+ def test_line(x):
+ m = tf.get_variable("w", shape=[],
+ initializer=tf.truncated_normal_initializer())
+ b = tf.get_variable("b", shape=[],
+ initializer=tf.truncated_normal_initializer())
+ return x * m + b
+
+ line_template = template.make_template("line", test_line)
+
+ train_prediction = line_template(training_input)
+ test_prediction = line_template(test_input)
+
+ train_loss = tf.reduce_mean(tf.square(train_prediction - training_output))
+ test_loss = tf.reduce_mean(tf.square(test_prediction - test_output))
+
+ optimizer = tf.train.GradientDescentOptimizer(0.1)
+ train_op = optimizer.minimize(train_loss)
+
+ with tf.Session() as sess:
+ sess.run(tf.initialize_all_variables())
+ initial_test_loss = sess.run(test_loss)
+ sess.run(train_op)
+ final_test_loss = sess.run(test_loss)
+
+ # Parameters are tied, so the loss should have gone down when we trained it.
+ self.assertLess(final_test_loss, initial_test_loss)
+
+ def test_skip_stack_frames(self):
+ first = traceback.format_stack()
+ second = traceback.format_stack()
+ result = template._skip_common_stack_elements(first, second)
+ self.assertEqual(1, len(result))
+ self.assertNotEqual(len(first), len(result))
+
+ def test_template_with_name(self):
+ tmpl1 = template.make_template("s1", var_scoped_function)
+ tmpl2 = template.make_template("s1", var_scoped_function)
+
+ v1 = tmpl1()
+ v2 = tmpl1()
+ v3 = tmpl2()
+ self.assertEqual(v1, v2)
+ self.assertNotEqual(v1, v3)
+ self.assertEqual("s1/dummy:0", v1.name)
+ self.assertEqual("s1_2/dummy:0", v3.name)
+
+ def test_template_in_scope(self):
+ tmpl1 = template.make_template("s1", var_scoped_function)
+ tmpl2 = template.make_template("s1", var_scoped_function)
+
+ with tf.variable_scope("scope"):
+ v1 = tmpl1()
+ v3 = tmpl2()
+
+ # The template contract requires the following to ignore scope2.
+ with tf.variable_scope("scope2"):
+ v2 = tmpl1()
+ self.assertEqual(v1, v2)
+ self.assertNotEqual(v1, v3)
+ self.assertEqual("scope/s1/dummy:0", v1.name)
+ self.assertEqual("scope/s1_1/dummy:0", v3.name)
+
+ def test_template_with_internal_reuse(self):
+ tmpl1 = template.make_template("s1", internally_var_scoped_function)
+ tmpl2 = template.make_template("s1", internally_var_scoped_function)
+
+ v1 = tmpl1("test")
+ v2 = tmpl1("test")
+ v3 = tmpl2("test")
+ self.assertEqual(v1, v2)
+ self.assertNotEqual(v1, v3)
+ self.assertEqual("s1/test/dummy:0", v1.name)
+ self.assertEqual("s1_2/test/dummy:0", v3.name)
+
+ with self.assertRaises(ValueError):
+ tmpl1("not_test")
+
+ def test_template_without_name(self):
+ with self.assertRaises(ValueError):
+ template.make_template(None, var_scoped_function)
+
+ def test_make_template(self):
+ # Test both that we can call it with positional and keywords.
+ tmpl1 = template.make_template(
+ "s1", internally_var_scoped_function, scope_name="test")
+ tmpl2 = template.make_template(
+ "s1", internally_var_scoped_function, scope_name="test")
+
+ v1 = tmpl1()
+ v2 = tmpl1()
+ v3 = tmpl2()
+ self.assertEqual(v1, v2)
+ self.assertNotEqual(v1, v3)
+ self.assertEqual("s1/test/dummy:0", v1.name)
+ self.assertEqual("s1_2/test/dummy:0", v3.name)
+
+ def test_enforces_no_extra_trainable_variables(self):
+ tmpl = template.make_template("s", function_with_create, trainable=True)
+
+ tmpl()
+ with self.assertRaises(ValueError):
+ tmpl()
+
+ def test_permits_extra_non_trainable_variables(self):
+ tmpl = template.make_template("s", function_with_create, trainable=False)
+ self.assertEqual(tmpl(), tmpl())
+
+ def test_internal_variable_reuse(self):
+ def nested():
+ with tf.variable_scope("nested") as vs:
+ v1 = tf.get_variable("x", initializer=tf.zeros_initializer, shape=[])
+ with tf.variable_scope(vs, reuse=True):
+ v2 = tf.get_variable("x")
+ self.assertEqual(v1, v2)
+ return v1
+
+ tmpl1 = template.make_template("s1", nested)
+ tmpl2 = template.make_template("s1", nested)
+
+ v1 = tmpl1()
+ v2 = tmpl1()
+ v3 = tmpl2()
+ self.assertEqual(v1, v2)
+ self.assertNotEqual(v1, v3)
+ self.assertEqual("s1/nested/x:0", v1.name)
+ self.assertEqual("s1_2/nested/x:0", v3.name)
+
+ def test_nested_templates(self):
+ def nested_template():
+ nested1 = template.make_template("nested", var_scoped_function)
+ nested2 = template.make_template("nested", var_scoped_function)
+ v1 = nested1()
+ v2 = nested2()
+ self.assertNotEqual(v1, v2)
+ return v2
+
+ tmpl1 = template.make_template("s1", nested_template)
+ tmpl2 = template.make_template("s1", nested_template)
+
+ v1 = tmpl1()
+ v2 = tmpl1()
+ v3 = tmpl2()
+ self.assertEqual(v1, v2)
+ self.assertNotEqual(v1, v3)
+ self.assertEqual("s1/nested_1/dummy:0", v1.name)
+ self.assertEqual("s1_2/nested_1/dummy:0", v3.name)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index bcbca6943a..613bdf49f0 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -646,10 +646,10 @@ def placeholder(dtype, shape=None, name=None):
y = tf.matmul(x, x)
with tf.Session() as sess:
- print sess.run(y) # ERROR: will fail because x was not fed.
+ print(sess.run(y)) # ERROR: will fail because x was not fed.
rand_array = np.random.rand(1024, 1024)
- print sess.run(y, feed_dict={x: rand_array}) # Will succeed.
+ print(sess.run(y, feed_dict={x: rand_array})) # Will succeed.
```
Args:
diff --git a/tensorflow/python/ops/constant_op.py b/tensorflow/python/ops/constant_op.py
index 5d8d8a88d0..a5f1a9c8ee 100644
--- a/tensorflow/python/ops/constant_op.py
+++ b/tensorflow/python/ops/constant_op.py
@@ -62,15 +62,15 @@ shuff = tf.random_shuffle(c)
# Each time we run these ops, different results are generated
sess = tf.Session()
-print sess.run(norm)
-print sess.run(norm)
+print(sess.run(norm))
+print(sess.run(norm))
# Set an op-level seed to generate repeatable sequences across sessions.
c = tf.constant([[1, 2], [3, 4], [5, 6]])
sess = tf.Session()
norm = tf.random_normal(c, seed=1234)
-print sess.run(norm)
-print sess.run(norm)
+print(sess.run(norm))
+print(sess.run(norm))
```
Another common use of random values is the initialization of variables. Also see
@@ -84,7 +84,7 @@ init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
-print sess.run(var)
+print(sess.run(var))
```
@@random_normal
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index f17211d677..4ee977d308 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -444,7 +444,7 @@ def gradients(ys,
op_wrapper = control_flow_ops.MakeWrapper(op)
in_grads = _AsList(grad_fn(op_wrapper, *out_grads))
_VerifyGeneratedGradients(in_grads, op)
- if gate_gradients and len(filter(None, in_grads)) > 1:
+ if gate_gradients and len(tuple(filter(None, in_grads))) > 1:
in_grads = control_flow_ops.tuple(in_grads)
logging.vlog(1, "Gradient for '" + op.name + "'")
logging.vlog(1, " in --> %s",
@@ -575,7 +575,8 @@ def _AggregatedGrads(grads, op, has_control_flow, aggregation_method=None):
if aggregation_method not in [AggregationMethod.ADD_N,
AggregationMethod.EXPERIMENTAL_TREE,
AggregationMethod.EXPERIMENTAL_ACCUMULATE_N]:
- raise ValueError("Invalid aggregation_method specified.")
+ raise ValueError(
+ "Invalid aggregation_method specified %s." % aggregation_method)
out_grads = _GetGrads(grads, op)
for i, out_grad in enumerate(out_grads):
if has_control_flow:
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index 419ea4f981..8f4644dab5 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -166,6 +166,7 @@ from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import common_shapes
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import gen_image_ops
+from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
@@ -607,6 +608,7 @@ def per_image_whitening(image):
variance = (math_ops.reduce_mean(math_ops.square(image)) -
math_ops.square(image_mean))
+ variance = gen_nn_ops.relu(variance)
stddev = math_ops.sqrt(variance)
# Apply a minimum normalization that protects us against uniform images.
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index a801a4f73b..cb7c9b8361 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -489,6 +489,14 @@ class PerImageWhiteningTest(test_util.TensorFlowTestCase):
y_tf = y.eval()
self.assertAllClose(y_tf, y_np, atol=1e-4)
+ def testUniformImage(self):
+ im_np = np.ones([19, 19, 3]).astype(np.float32) * 249
+ im = constant_op.constant(im_np)
+ whiten = image_ops.per_image_whitening(im)
+ with self.test_session():
+ whiten_np = whiten.eval()
+ self.assertFalse(np.any(np.isnan(whiten_np)))
+
class CropToBoundingBoxTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 623afbcf31..2075e3c913 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -56,5 +56,6 @@ from tensorflow.python.ops.summary_ops import image_summary
from tensorflow.python.ops.summary_ops import merge_all_summaries
from tensorflow.python.ops.summary_ops import merge_summary
from tensorflow.python.ops.summary_ops import scalar_summary
+from tensorflow.python.ops.template import *
from tensorflow.python.ops.variable_scope import *
from tensorflow.python.ops.variables import *
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index 25c679b80b..ef726c78e5 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -45,6 +45,7 @@ create variables contingent on certain conditions.
@@get_variable
@@get_variable_scope
+@@make_template
@@variable_op_scope
@@variable_scope
diff --git a/tensorflow/python/ops/summary_ops.py b/tensorflow/python/ops/summary_ops.py
index ccf9ab191c..800ab7bc7e 100644
--- a/tensorflow/python/ops/summary_ops.py
+++ b/tensorflow/python/ops/summary_ops.py
@@ -60,7 +60,7 @@ def histogram_summary(tag, values, collections=None, name=None):
return val
-def image_summary(tag, tensor, max_images=None, collections=None, name=None):
+def image_summary(tag, tensor, max_images=3, collections=None, name=None):
"""Outputs a `Summary` protocol buffer with images.
The summary has up to `max_images` summary values containing images. The
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py
new file mode 100644
index 0000000000..60cff4b97b
--- /dev/null
+++ b/tensorflow/python/ops/template.py
@@ -0,0 +1,215 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Provides templates which allow variable sharing."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import traceback
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import logging
+
+
+__all__ = ["make_template"]
+
+
+def make_template(name_, func_, **kwargs):
+ """Given an arbitrary function, wrap it so that it does variable sharing.
+
+ This wraps `func_` in a Template and partially evaluates it. Templates are
+ functions that create variables the first time they are called and reuse them
+ thereafter. In order for `func_` to be compatible with a `Template` it must
+ have the following properties:
+
+ * The function should create all trainable variables and any variables that
+ should be reused by calling `tf.get_variable`. If a trainable variable is
+ created using `tf.Variable`, then a ValueError will be thrown. Variables
+ that are intended to be locals can be created by specifying
+ `tf.Variable(..., trainable=false)`.
+ * The function may use variable scopes and other templates internally to
+ create and reuse variables, but it shouldn't use `tf.get_variables` to
+ capture variables that are defined outside of the scope of the function.
+ * Internal scopes and variable names should not depend on any arguments that
+ are not supplied to `make_template`. In general you will get a ValueError
+ telling you that you are trying to reuse a variable that doesn't exist
+ if you make a mistake.
+
+ In the following example, both `z` and `w` will be scaled by the same `y`. It
+ is important to note that if we didn't assign `scalar_name` and used a
+ different name for z and w that a `ValueError` would be thrown because it
+ couldn't reuse the variable.
+
+ ```python
+ def my_op(x, scalar_name):
+ var1 = tf.get_variable(scalar_name,
+ shape=[],
+ initializer=tf.constant_initializer(1))
+ return x * var1
+
+ scale_by_y = tf.make_template('scale_by_y', my_op, scalar_name='y')
+
+ z = scale_by_y(input1)
+ w = scale_by_y(input2)
+ ```
+
+ As a safe-guard, the returned function will raise a `ValueError` after the
+ first call if trainable variables are created by calling `tf.Variable`.
+
+ If all of these are true, then 2 properties are enforced by the template:
+
+ 1. Calling the same template multiple times will share all non-local
+ variables.
+ 2. Two different templates are guaranteed to be unique, unless you reenter the
+ same variable scope as the initial definition of a template and redefine
+ it. An examples of this exception:
+
+ ```python
+ def my_op(x, scalar_name):
+ var1 = tf.get_variable(scalar_name,
+ shape=[],
+ initializer=tf.constant_initializer(1))
+ return x * var1
+
+ with tf.variable_scope('scope') as vs:
+ scale_by_y = tf.make_template('scale_by_y', my_op, scalar_name='y')
+ z = scale_by_y(input1)
+ w = scale_by_y(input2)
+
+ # Creates a template that reuses the variables above.
+ with tf.variable_scope(vs, reuse=True):
+ scale_by_y2 = tf.make_template('scale_by_y', my_op, scalar_name='y')
+ z2 = scale_by_y2(input1)
+ w2 = scale_by_y2(input2)
+ ```
+
+ Note: The full variable scope is captured at the time of the first call.
+
+ Note: `name_` and `func_` have a following underscore to reduce the likelihood
+ of collisions with kwargs.
+
+ Args:
+ name_: A name for the scope created by this template. If necessary, the name
+ will be made unique by appending `_N` to the name.
+ func_: The function to wrap.
+ **kwargs: Keyword arguments to apply to `func_`.
+
+ Returns:
+ A function that will enter a `variable_scope` before calling `func_`. The
+ first time it is called, it will create a non-reusing scope so that the
+ variables will be unique. On each subsequent call, it will reuse those
+ variables.
+
+ Raises:
+ ValueError: if the name is None.
+ """
+ if kwargs:
+ func_ = functools.partial(func_, **kwargs)
+ return Template(name_, func_)
+
+
+def _skip_common_stack_elements(stacktrace, base_case):
+ """Skips items that the target stacktrace shares with the base stacktrace."""
+ for i, (trace, base) in enumerate(zip(stacktrace, base_case)):
+ if trace != base:
+ return stacktrace[i:]
+ return stacktrace[-1:]
+
+
+class Template(object):
+ """Wrap a function to aid in variable sharing.
+
+ Templates are functions that create variables the first time they are called
+ and reuse them thereafter. See `make_template` for full documentation.
+
+ Note: The full variable scope is captured at the time of the first call.
+ """
+
+ def __init__(self, name, func):
+ """Creates a template for the given function.
+
+ Args:
+ name: A name for the scope created by this template. The
+ name will be made unique by appending `_N` to the it (see how
+ `tf.variable_op_scope` treats the `default_name` for details).
+ func: The function to apply each time.
+
+ Raises:
+ ValueError: if the name is None.
+ """
+ self._func = func
+ self._stacktrace = traceback.format_stack()[:-2]
+ self._name = name
+ if name is None:
+ raise ValueError("name cannot be None.")
+ self._var_scope = None
+
+ def _call_func(self, args, kwargs, check_for_new_variables):
+ try:
+ vars_at_start = len(ops.get_collection(ops.GraphKeys.VARIABLES))
+ trainable_at_start = len(
+ ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
+
+ result = self._func(*args, **kwargs)
+ if check_for_new_variables:
+ trainable_variables = ops.get_collection(
+ ops.GraphKeys.TRAINABLE_VARIABLES)
+ # If a variable that we intend to train is created as a side effect
+ # of creating a template, then that is almost certainly an error.
+ if trainable_at_start != len(trainable_variables):
+ raise ValueError("Trainable variable created when calling a template "
+ "after the first time, perhaps you used tf.Variable "
+ "when you meant tf.get_variable: %s" %
+ (trainable_variables[trainable_at_start:],))
+
+ # Non-trainable tracking variables are a legitimate reason why a new
+ # variable would be created, but it is a relatively advanced use-case,
+ # so log it.
+ variables = ops.get_collection(ops.GraphKeys.VARIABLES)
+ if vars_at_start != len(variables):
+ logging.info("New variables created when calling a template after "
+ "the first time, perhaps you used tf.Variable when you "
+ "meant tf.get_variable: %s",
+ variables[vars_at_start:])
+ return result
+ except Exception, exc:
+ # Reraise the exception, but append the original definition to the
+ # trace.
+ args = exc.args
+ if not args:
+ arg0 = ""
+ else:
+ arg0 = args[0]
+ trace = "".join(_skip_common_stack_elements(self._stacktrace,
+ traceback.format_stack()))
+ arg0 = "%s\n\noriginally defined at:\n%s" % (arg0, trace)
+ new_args = [arg0]
+ new_args.extend(args[1:])
+ exc.args = tuple(new_args)
+ raise
+
+ def __call__(self, *args, **kwargs):
+ # Capture the name of the variable_scope here because if we capture at
+ # construction, then name_scopes would have a '_N+1' suffix.
+ if self._var_scope:
+ with variable_scope.variable_scope(self._var_scope, reuse=True):
+ return self._call_func(args, kwargs, check_for_new_variables=True)
+ else:
+ with variable_scope.variable_op_scope([], None, self._name) as vs:
+ self._var_scope = vs
+ return self._call_func(args, kwargs, check_for_new_variables=False)
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index adfed93e66..4756aec896 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -284,10 +284,10 @@ class Variable(object):
with tf.Session() as sess:
sess.run(init)
# Usage passing the session explicitly.
- print v.eval(sess)
+ print(v.eval(sess))
# Usage with the default session. The 'with' block
# above makes 'sess' the default session.
- print v.eval()
+ print(v.eval())
```
Args:
diff --git a/tensorflow/python/summary/event_accumulator.py b/tensorflow/python/summary/event_accumulator.py
index ab5fb4a426..56bcd15d2b 100644
--- a/tensorflow/python/summary/event_accumulator.py
+++ b/tensorflow/python/summary/event_accumulator.py
@@ -177,7 +177,7 @@ class EventAccumulator(object):
## file_version events always have step 0, ignore.
## TODO(danmane): Have this check for restart events explicitly
if (event.step < self.most_recent_step and
- not event.HasField('file_version')):
+ event.HasField('summary')):
## Keep data in reservoirs that has a step less than event.step
_NotExpired = lambda x: x.step < event.step
diff --git a/tensorflow/python/summary/event_accumulator_test.py b/tensorflow/python/summary/event_accumulator_test.py
index a64084d826..394a6d2290 100644
--- a/tensorflow/python/summary/event_accumulator_test.py
+++ b/tensorflow/python/summary/event_accumulator_test.py
@@ -24,6 +24,7 @@ import tensorflow.python.platform
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
+from tensorflow.core.framework import graph_pb2
from tensorflow.python.platform import gfile
from tensorflow.python.summary import event_accumulator as ea
@@ -391,13 +392,15 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
## Check that we have discarded 200 and 300
self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 101, 201, 301])
- def testFileVersionEventDoesntTriggerDiscard(self):
+ def testOnlySummaryEventsTriggerDiscards(self):
"""Test that file version event doesnt trigger data purge."""
gen = _EventGenerator()
acc = ea.EventAccumulator(gen)
gen.AddScalar('s1', wall_time=1, step=100, value=20)
- ev = tf.Event(wall_time=2, step=0, file_version='0')
- gen.AddEvent(ev)
+ ev1 = tf.Event(wall_time=2, step=0, file_version='0')
+ ev2 = tf.Event(wall_time=3, step=0, graph_def=graph_pb2.GraphDef())
+ gen.AddEvent(ev1)
+ gen.AddEvent(ev2)
acc.Reload()
self.assertEqual([x.step for x in acc.Scalars('s1')], [100])
diff --git a/tensorflow/python/summary/impl/event_file_loader.py b/tensorflow/python/summary/impl/event_file_loader.py
index bd5deb0e0a..7c7f6ca70d 100644
--- a/tensorflow/python/summary/impl/event_file_loader.py
+++ b/tensorflow/python/summary/impl/event_file_loader.py
@@ -49,7 +49,6 @@ class EventFileLoader(object):
All values that were written to disk that have not been yielded yet.
"""
while self._reader.GetNext():
- logging.debug('Got an event from %s', self._file_path)
event = event_pb2.Event()
event.ParseFromString(self._reader.record())
yield event
diff --git a/tensorflow/python/training/summary_io.py b/tensorflow/python/training/summary_io.py
index f10e984e6d..12b0548d46 100644
--- a/tensorflow/python/training/summary_io.py
+++ b/tensorflow/python/training/summary_io.py
@@ -217,7 +217,7 @@ def summary_iterator(path):
```python
for e in tf.summary_iterator(path to events file):
- print e
+ print(e)
```
Example: Print selected summary values.
@@ -230,7 +230,7 @@ def summary_iterator(path):
for e in tf.summary_iterator(path to events file):
for v in e.summary.value:
if v.tag == 'loss':
- print v.simple_value
+ print(v.simple_value)
```
See the protocol buffer definitions of
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index 3d2b6d9c84..54f8cea031 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -33,7 +33,7 @@ def global_step(sess, global_step_tensor):
sess = tf.Session()
# Initializes the variable.
sess.run(global_step_tensor.initializer)
- print 'global_step:', tf.train.global_step(sess, global_step_tensor)
+ print('global_step:', tf.train.global_step(sess, global_step_tensor))
global_step: 10
```
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index fbc9342081..4dc60aea53 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -162,6 +162,7 @@ string BatchDescriptor::ToShortString() const {
return port::StrCat(batch, depth, y, x, suffix);
default:
LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
+ return ""; // Avoid return warning (unreachable)
}
}
@@ -243,6 +244,7 @@ string FilterDescriptor::ToShortString() const {
return port::StrCat(y, x, id, od);
default:
LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout_);
+ return ""; // Avoid return warning (unreachable)
}
}
diff --git a/tensorflow/tensorboard/CHANGES b/tensorflow/tensorboard/CHANGES
index 88e252f457..7e3083f58e 100644
--- a/tensorflow/tensorboard/CHANGES
+++ b/tensorflow/tensorboard/CHANGES
@@ -3,4 +3,8 @@ Begin tracking TensorBoard changes.
--- 3 ---
Change default # of scalar values to 1000
-Fix bug where TensorBoard discards all values after a restart. \ No newline at end of file
+Fix bug where TensorBoard discards all values after a restart.
+
+--- 4 ---
+Fix another case where TensorBoard discards values after a restart.
+Add a "toggle all runs" button. \ No newline at end of file
diff --git a/tensorflow/tensorboard/TAG b/tensorflow/tensorboard/TAG
index 00750edc07..b8626c4cff 100644
--- a/tensorflow/tensorboard/TAG
+++ b/tensorflow/tensorboard/TAG
@@ -1 +1 @@
-3
+4