aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2017-02-17 17:05:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-17 17:23:48 -0800
commit93a975e114ee1c35f01ed3bdd47170e6f7129014 (patch)
treee34255aff698fe6a4a586e7940337fd278947f58 /tensorflow/python
parenteb9624017a0040e805fda622a5f9ec6681e24246 (diff)
Merge changes from github.
Change: 147897309
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/BUILD36
-rw-r--r--tensorflow/python/framework/errors_impl.py2
-rw-r--r--tensorflow/python/framework/function.py4
-rw-r--r--tensorflow/python/framework/op_def_library.py4
-rw-r--r--tensorflow/python/framework/op_def_library_test.py3
-rw-r--r--tensorflow/python/kernel_tests/constant_op_test.py6
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py6
-rw-r--r--tensorflow/python/lib/io/file_io.py4
-rw-r--r--tensorflow/python/lib/io/file_io_test.py1
-rw-r--r--tensorflow/python/ops/array_grad.py1
-rw-r--r--tensorflow/python/ops/array_ops.py4
-rw-r--r--tensorflow/python/ops/control_flow_ops.py46
-rw-r--r--tensorflow/python/ops/variables.py5
-rw-r--r--tensorflow/python/training/supervisor.py2
-rw-r--r--tensorflow/python/training/sync_replicas_optimizer_test.py1
15 files changed, 89 insertions, 36 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 5353035b18..04e1afaf81 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2636,7 +2636,6 @@ cuda_py_tests(
"training/proximal_gradient_descent_test.py",
"training/queue_runner_test.py",
"training/rmsprop_test.py",
- "training/saver_test.py",
"training/slot_creator_test.py",
"training/tensorboard_logging_test.py",
"training/training_ops_test.py",
@@ -2678,6 +2677,41 @@ cuda_py_tests(
],
)
+cuda_py_test(
+ name = "saver_test",
+ size = "medium",
+ srcs = [
+ "training/saver_test.py",
+ ],
+ additional_deps = [
+ ":array_ops",
+ ":client_testlib",
+ ":control_flow_ops",
+ ":data_flow_ops",
+ ":data_flow_ops_gen",
+ ":errors",
+ ":gradients",
+ ":math_ops",
+ ":nn_grad",
+ ":nn_ops",
+ ":partitioned_variables",
+ ":platform",
+ ":platform_test",
+ ":pywrap_tensorflow",
+ ":random_ops",
+ ":resource_variable_ops",
+ ":sparse_ops",
+ ":summary",
+ ":training",
+ ":util",
+ ":variable_scope",
+ ":variables",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
py_test(
name = "saver_large_variable_test",
size = "small",
diff --git a/tensorflow/python/framework/errors_impl.py b/tensorflow/python/framework/errors_impl.py
index 04a6e4d7fb..32c96ec947 100644
--- a/tensorflow/python/framework/errors_impl.py
+++ b/tensorflow/python/framework/errors_impl.py
@@ -456,8 +456,8 @@ def _make_specific_exception(node_def, op, message, error_code):
@contextlib.contextmanager
def raise_exception_on_not_ok_status():
+ status = pywrap_tensorflow.TF_NewStatus()
try:
- status = pywrap_tensorflow.TF_NewStatus()
yield status
if pywrap_tensorflow.TF_GetCode(status) != 0:
raise _make_specific_exception(
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 46da2646ec..7c0201f93e 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -769,6 +769,10 @@ class Defun(object):
default graph and adds the definition of the function into the
default graph. Because the addition of the function into the graph
is deferred, the decorator can be used anywhere in the program.
+
+ Definitions of functions are frozen in a graph as soon as the graph is used to
+ create a session. Therefore, nodes using the function must be created in the
+ graph before the corresponding session is created.
Example, but also see the [How To on functions](link_needed).
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index cb79954226..7f2b03e350 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -618,8 +618,8 @@ class OpDefLibrary(object):
if input_arg.is_ref:
if not all(x._is_ref_dtype for x in types): # pylint: disable=protected-access
raise TypeError(
- "Input '%s' of '%s' Op requires l-value input" %
- (input_name, op_type_name))
+ ("'%s' Op requires that input '%s' be a mutable tensor " +
+ "(e.g.: a tf.Variable)") % (op_type_name, input_name))
input_types.extend(types)
else:
input_types.extend(base_types)
diff --git a/tensorflow/python/framework/op_def_library_test.py b/tensorflow/python/framework/op_def_library_test.py
index 0fc7f0b353..715e863b78 100644
--- a/tensorflow/python/framework/op_def_library_test.py
+++ b/tensorflow/python/framework/op_def_library_test.py
@@ -1462,7 +1462,8 @@ class OpDefLibraryTest(test_util.TensorFlowTestCase):
with self.assertRaises(TypeError) as cm:
self._lib.apply_op("RefIn", a=2)
self.assertEqual(str(cm.exception),
- "Input 'a' of 'RefIn' Op requires l-value input")
+ "'RefIn' Op requires that input 'a' be a mutable tensor " +
+ "(e.g.: a tf.Variable)")
input_a = self._lib.apply_op("RefOut", T=dtypes.int32, name="t")
input_b = self._lib.apply_op("RefOut", T=dtypes.int32, name="u")
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index fe93a30668..128a6529f0 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -771,6 +771,12 @@ class PlaceholderWithDefaultTest(test.TestCase):
self.assertAllEqual(
[[3, 3], [3, 3]], a.eval(feed_dict={p: [[3, 3], [3, 3]]}))
+ def testGradient(self):
+ with self.test_session():
+ x = array_ops.placeholder(dtypes_lib.float32, [5, 7])
+ y = array_ops.placeholder_with_default(x, None)
+ err = gradient_checker.compute_gradient_error(x, [5, 7], y, [5, 7])
+ self.assertLess(err, 1e-3)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index b96d16e54b..11b350a99e 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -419,6 +419,12 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose(np.ones((5, 5), np.float32), var.eval())
+ def testRepr(self):
+ var = variables.Variable(np.zeros((5, 5), np.float32), name='noop')
+ self.assertEqual(
+ "<tf.Variable 'noop:0' shape=(5, 5) dtype=float32_ref>",
+ repr(var))
+
class IsInitializedTest(test.TestCase):
diff --git a/tensorflow/python/lib/io/file_io.py b/tensorflow/python/lib/io/file_io.py
index ddd117e443..ace03e3d1b 100644
--- a/tensorflow/python/lib/io/file_io.py
+++ b/tensorflow/python/lib/io/file_io.py
@@ -146,9 +146,7 @@ class FileIO(object):
def tell(self):
"""Returns the current position in the file."""
- if not self._read_check_passed:
- raise errors.PermissionDeniedError(None, None,
- "File isn't open for reading")
+ self._preread_check()
return self._read_buf.Tell()
def __enter__(self):
diff --git a/tensorflow/python/lib/io/file_io_test.py b/tensorflow/python/lib/io/file_io_test.py
index 0063eebb59..72931217d9 100644
--- a/tensorflow/python/lib/io/file_io_test.py
+++ b/tensorflow/python/lib/io/file_io_test.py
@@ -354,6 +354,7 @@ class FileIoTest(test.TestCase):
file_path = os.path.join(self._base_dir, "temp_file")
with file_io.FileIO(file_path, mode="r+") as f:
f.write("testing1\ntesting2\ntesting3\n\ntesting5")
+ self.assertEqual(0, f.tell())
self.assertEqual("testing1\n", f.readline())
self.assertEqual(9, f.tell())
self.assertEqual("testing2\n", f.readline())
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index fa1dda29ad..8d66452b4b 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -382,6 +382,7 @@ def _CheckNumericsGrad(_, grad):
grad, "Not a number (NaN) or infinity (Inf) values detected in gradient.")
+@ops.RegisterGradient("PlaceholderWithDefault")
@ops.RegisterGradient("Identity")
def _IdGrad(_, grad):
return grad
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 61cb5666a7..97f80a3f06 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -616,8 +616,8 @@ def strided_slice(input_,
tf.strided_slice(input, [1, 0, 0], [2, 1, 3], [1, 1, 1]) ==> [[[3, 3, 3]]]
tf.strided_slice(input, [1, 0, 0], [2, 2, 3], [1, 1, 1]) ==> [[[3, 3, 3],
[4, 4, 4]]]
- tf.strided_slice(input, [1, 1, 0], [2, -1, 3], [1, -1, 1]) ==>[[[4, 4, 4],
- [3, 3, 3]]]
+ tf.strided_slice(input, [1, -1, 0], [2, -3, 3], [1, -1, 1]) ==>[[[4, 4, 4],
+ [3, 3, 3]]]
```
Args:
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 2f308b170b..243c4ed033 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -2572,35 +2572,35 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
Example:
- ```python
- i = tf.constant(0)
- c = lambda i: tf.less(i, 10)
- b = lambda i: tf.add(i, 1)
- r = tf.while_loop(c, b, [i])
- ```
+ ```python
+ i = tf.constant(0)
+ c = lambda i: tf.less(i, 10)
+ b = lambda i: tf.add(i, 1)
+ r = tf.while_loop(c, b, [i])
+ ```
Example with nesting and a namedtuple:
- ```python
- import collections
- Pair = collections.namedtuple('Pair', 'j, k')
- ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
- c = lambda i, p: i < 10
- b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
- ijk_final = tf.while_loop(c, b, ijk_0)
- ```
+ ```python
+ import collections
+ Pair = collections.namedtuple('Pair', 'j, k')
+ ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
+ c = lambda i, p: i < 10
+ b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
+ ijk_final = tf.while_loop(c, b, ijk_0)
+ ```
Example using shape_invariants:
- ```python
- i0 = tf.constant(0)
- m0 = tf.ones([2, 2])
- c = lambda i, m: i < 10
- b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
- tf.while_loop(
- c, b, loop_vars=[i0, m0],
- shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
- ```
+ ```python
+ i0 = tf.constant(0)
+ m0 = tf.ones([2, 2])
+ c = lambda i, m: i < 10
+ b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
+ tf.while_loop(
+ c, b, loop_vars=[i0, m0],
+ shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
+ ```
"""
with ops.name_scope(name, "while", loop_vars) as name:
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index dc73ad78a7..5a1a43b5d5 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -196,8 +196,9 @@ class Variable(object):
dtype=dtype,
expected_shape=expected_shape)
- def __str__(self):
- return str(self._snapshot)
+ def __repr__(self):
+ return "<tf.Variable '%s' shape=%s dtype=%s>" % (
+ self.name, self.get_shape(), self.dtype.name)
def _init_from_args(self,
initial_value=None,
diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py
index 7b93391911..884c1b6182 100644
--- a/tensorflow/python/training/supervisor.py
+++ b/tensorflow/python/training/supervisor.py
@@ -1027,7 +1027,7 @@ class SVStepCounterThread(coordinator.LooperThread):
elapsed_time = current_time - self._last_time
self._last_time = current_time
# Reports the number of steps done per second
- steps_per_sec = added_steps / elapsed_time
+ steps_per_sec = added_steps / elapsed_time if elapsed_time != 0. else float("inf")
summary = Summary(value=[Summary.Value(tag=self._summary_tag,
simple_value=steps_per_sec)])
if self._sv.summary_writer:
diff --git a/tensorflow/python/training/sync_replicas_optimizer_test.py b/tensorflow/python/training/sync_replicas_optimizer_test.py
index 32cae70460..15f938df8c 100644
--- a/tensorflow/python/training/sync_replicas_optimizer_test.py
+++ b/tensorflow/python/training/sync_replicas_optimizer_test.py
@@ -267,6 +267,7 @@ class SyncReplicasOptimizerTest(test.TestCase):
# Starts worker 1.
thread_1.start()
thread_1.join()
+ thread_0.join()
# The global step should now be 2 and the gradients should have been
# applied again.