author     Jacques Pienaar <jpienaar@google.com>          2018-03-21 12:07:51 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-03-21 12:10:30 -0700
commit  2d0531d72c7dcbb0e149cafdd3a16ee8c3ff357a (patch)
tree    1179ecdd684d10c6549f85aa95f33dd79463a093 /tensorflow/python
parent  cbede3ea7574b36f429710bc08617d08455bcc21 (diff)
Merge changes from github.
PiperOrigin-RevId: 189945839
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/BUILD                                      |  12
-rw-r--r--  tensorflow/python/client/timeline_test.py                    |   7
-rw-r--r--  tensorflow/python/estimator/estimator.py                     |  34
-rw-r--r--  tensorflow/python/estimator/run_config.py                    |   2
-rw-r--r--  tensorflow/python/estimator/training.py                      |  26
-rw-r--r--  tensorflow/python/keras/_impl/keras/engine/training.py       |   2
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/recurrent.py      |   4
-rw-r--r--  tensorflow/python/keras/_impl/keras/utils/generic_utils.py   |   4
-rw-r--r--  tensorflow/python/keras/_impl/keras/utils/vis_utils.py       |   2
-rw-r--r--  tensorflow/python/kernel_tests/concat_op_test.py             |  11
-rw-r--r--  tensorflow/python/kernel_tests/conv_ops_test.py              |  20
-rw-r--r--  tensorflow/python/kernel_tests/depthtospace_op_test.py       |  10
-rw-r--r--  tensorflow/python/kernel_tests/spacetodepth_op_test.py       |  10
-rw-r--r--  tensorflow/python/layers/base.py                             |   2
-rw-r--r--  tensorflow/python/layers/normalization.py                    |   9
-rw-r--r--  tensorflow/python/lib/io/file_io_test.py                     |   5
-rw-r--r--  tensorflow/python/lib/io/tf_record.py                        |  18
-rw-r--r--  tensorflow/python/ops/linalg_ops.py                          |   2
-rw-r--r--  tensorflow/python/ops/nn_ops.py                              |  16
-rw-r--r--  tensorflow/python/ops/random_ops.py                          |   2
-rw-r--r--  tensorflow/python/ops/rnn.py                                 |  17
-rw-r--r--  tensorflow/python/ops/special_math_ops.py                    |   4
-rw-r--r--  tensorflow/python/ops/special_math_ops_test.py               |   5
-rw-r--r--  tensorflow/python/tools/freeze_graph.py                      |  36
-rw-r--r--  tensorflow/python/tools/inspect_checkpoint.py                |   4
-rw-r--r--  tensorflow/python/tools/saved_model_cli.py                   |  60
-rw-r--r--  tensorflow/python/tools/saved_model_cli_test.py              |  22
-rw-r--r--  tensorflow/python/training/saver.py                          |   5
28 files changed, 250 insertions, 101 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index d11ee6f74c..54e944c264 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -86,7 +86,6 @@ py_library(
":ops",
":platform",
":pywrap_tensorflow",
- ":saver_test_utils",
":script_ops",
":session_ops",
":sets",
@@ -96,14 +95,15 @@ py_library(
":standard_ops",
":state_ops",
":string_ops",
- ":subscribe",
":summary",
":tensor_array_ops",
+ ":training",
+ ":saver_test_utils",
+ ":subscribe",
":test_ops", # TODO: Break testing code out into separate rule.
- ":tf_cluster",
":tf_item",
+ ":tf_cluster",
":tf_optimizer",
- ":training",
":util",
":weights_broadcast_ops",
"//third_party/py/numpy",
@@ -3971,7 +3971,11 @@ py_test(
srcs = ["training/checkpoint_utils_test.py"],
srcs_version = "PY2AND3",
tags = [
+ "manual",
+ "no_cuda_on_cpu_tap",
+ "no_oss",
"no_windows",
+ "notap",
],
deps = [
":client",
diff --git a/tensorflow/python/client/timeline_test.py b/tensorflow/python/client/timeline_test.py
index 9641b8b7f2..5e6b5acdb0 100644
--- a/tensorflow/python/client/timeline_test.py
+++ b/tensorflow/python/client/timeline_test.py
@@ -155,9 +155,12 @@ class TimelineTest(test.TestCase):
ctf = step_analysis.chrome_trace.format_to_string()
self._validateTrace(ctf)
maximums = step_analysis.allocator_maximums
- self.assertTrue('cpu' in maximums)
+ cpuname = 'cpu'
+ if 'mklcpu' in maximums:
+ cpuname = 'mkl' + cpuname
+ self.assertTrue(cpuname in maximums)
cpu_max = maximums[
- 'cuda_host_bfc'] if 'cuda_host_bfc' in maximums else maximums['cpu']
+ 'cuda_host_bfc'] if 'cuda_host_bfc' in maximums else maximums[cpuname]
# At least num1 + num2, both float32s (4 bytes each)
self.assertGreater(cpu_max.num_bytes, 8)
self.assertGreater(cpu_max.timestamp, 0)
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 9fcbd4ff77..6a4132bca2 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -139,8 +139,8 @@ class Estimator(object):
to configure Estimators from hyper parameter tuning.
* `config`: Optional configuration object. Will receive what is passed
to Estimator in `config` parameter, or the default `config`.
- Allows updating things in your model_fn based on configuration
- such as `num_ps_replicas`, or `model_dir`.
+ Allows updating things in your `model_fn` based on
+ configuration such as `num_ps_replicas`, or `model_dir`.
* Returns:
`EstimatorSpec`
@@ -301,11 +301,11 @@ class Estimator(object):
* A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
tuple (features, labels) with same constraints as below.
- * A tuple (features, labels): Where features is a `Tensor` or a
- dictionary of string feature name to `Tensor` and labels is a
+ * A tuple (features, labels): Where `features` is a `Tensor` or a
+ dictionary of string feature name to `Tensor` and `labels` is a
`Tensor` or a dictionary of string label name to `Tensor`. Both
- features and labels are consumed by `model_fn`. They should satisfy
- the expectation of `model_fn` from inputs.
+ `features` and `labels` are consumed by `model_fn`. They should
+ satisfy the expectation of `model_fn` from inputs.
hooks: List of `SessionRunHook` subclass instances. Used for callbacks
inside the training loop.
@@ -381,11 +381,11 @@ class Estimator(object):
* A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
tuple (features, labels) with same constraints as below.
- * A tuple (features, labels): Where features is a `Tensor` or a
- dictionary of string feature name to `Tensor` and labels is a
+ * A tuple (features, labels): Where `features` is a `Tensor` or a
+ dictionary of string feature name to `Tensor` and `labels` is a
`Tensor` or a dictionary of string label name to `Tensor`. Both
- features and labels are consumed by `model_fn`. They should satisfy
- the expectation of `model_fn` from inputs.
+ `features` and `labels` are consumed by `model_fn`. They should
+ satisfy the expectation of `model_fn` from inputs.
steps: Number of steps for which to evaluate model. If `None`, evaluates
until `input_fn` raises an end-of-input exception.
@@ -457,17 +457,17 @@ class Estimator(object):
checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
latest checkpoint in `model_dir` is used.
yield_single_examples: If False, yield the whole batch as returned by the
- model_fn instead of decomposing the batch into individual elements. This
- is useful if model_fn return some tensor with first dimension not
- equal to the batch size
+ `model_fn` instead of decomposing the batch into individual elements.
+ This is useful if `model_fn` returns some tensors whose first dimension
+ is not equal to the batch size.
Yields:
Evaluated values of `predictions` tensors.
Raises:
- ValueError: Could not find a trained model in model_dir.
- ValueError: if batch length of predictions are not same and
- yield_single_examples is True.
+ ValueError: Could not find a trained model in `model_dir`.
+ ValueError: If batch length of predictions is not the same and
+ `yield_single_examples` is True.
ValueError: If there is a conflict between `predict_keys` and
`predictions`. For example if `predict_keys` is not `None` but
`EstimatorSpec.predictions` is not a `dict`.
@@ -849,7 +849,7 @@ class Estimator(object):
'loss': estimator_spec.loss,
'step': global_step_tensor
},
- every_n_iter=100)
+ every_n_iter=self._config.log_step_count_steps)
])
worker_hooks.extend(estimator_spec.training_hooks)
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index 62f035bce5..820fda7765 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -423,7 +423,7 @@ class RunConfig(object):
to be saved. The default value of 10,000 hours effectively disables
the feature.
log_step_count_steps: The frequency, in number of global steps, that the
- global step/sec will be logged during training.
+ global step/sec and the loss will be logged during training.
Raises:
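The two hunks above work together: `estimator.py` replaces the hard-coded `every_n_iter=100` in the loss-logging hook with `RunConfig.log_step_count_steps`, and the `run_config.py` docstring is updated to say that loss is now logged at that frequency too. A minimal sketch of how a caller would control the logging cadence (`my_model_fn` is a placeholder):

```python
import tensorflow as tf

def my_model_fn(features, labels, mode):
  # Placeholder model_fn: a single linear unit with squared loss.
  predictions = tf.layers.dense(features['x'], 1)
  loss = tf.losses.mean_squared_error(labels, predictions)
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
      loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(
      mode, predictions=predictions, loss=loss, train_op=train_op)

# Log global step/sec and loss every 500 steps instead of the default 100.
config = tf.estimator.RunConfig(log_step_count_steps=500)
estimator = tf.estimator.Estimator(model_fn=my_model_fn, config=config)
```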
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 2cc3331a15..e38b765da5 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -128,9 +128,16 @@ class TrainSpec(
"""Creates a validated `TrainSpec` instance.
Args:
- input_fn: Training input function returning a tuple of:
- features - `Tensor` or dictionary of string feature name to `Tensor`.
- labels - `Tensor` or dictionary of `Tensor` with labels.
+ input_fn: A function that provides input data for training as minibatches.
+ See @{$get_started/premade_estimators#create_input_functions} for more
+ information. The function should construct and return one of
+ the following:
+ * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
+ tuple (features, labels) with same constraints as below.
+ * A tuple (features, labels): Where features is a `Tensor` or a
+ dictionary of string feature name to `Tensor` and labels is a
+ `Tensor` or a dictionary of string label name to `Tensor`.
+
max_steps: Int. Positive number of total steps for which to train model.
If `None`, train forever. The training `input_fn` is not expected to
generate `OutOfRangeError` or `StopIteration` exceptions. See the
@@ -185,9 +192,16 @@ class EvalSpec(
"""Creates a validated `EvalSpec` instance.
Args:
- input_fn: Evaluation input function returning a tuple of:
- features - `Tensor` or dictionary of string feature name to `Tensor`.
- labels - `Tensor` or dictionary of `Tensor` with labels.
+ input_fn: A function that constructs the input data for evaluation.
+ See @{$get_started/premade_estimators#create_input_functions} for more
+ information. The function should construct and return one of
+ the following:
+ * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a
+ tuple (features, labels) with same constraints as below.
+ * A tuple (features, labels): Where features is a `Tensor` or a
+ dictionary of string feature name to `Tensor` and labels is a
+ `Tensor` or a dictionary of string label name to `Tensor`.
+
steps: Int. Positive number of steps for which to evaluate model. If
`None`, evaluates until `input_fn` raises an end-of-input exception.
See `Estimator.evaluate` for details.
diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index 57506f9aff..4acb41553e 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -266,7 +266,7 @@ class Model(Network):
# initialization for Eager mode execution
if context.executing_eagerly():
if target_tensors is not None:
- raise ValueError('target_tensors are not currently supported in Eager'
+ raise ValueError('target_tensors are not currently supported in Eager '
'mode.')
self.total_loss = None
self.metrics_tensors = []
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
index 2910719807..791f9b3113 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
@@ -546,8 +546,8 @@ class RNN(Layer):
raise ValueError('The initial state or constants of an RNN'
' layer cannot be specified with a mix of'
' Keras tensors and non-Keras tensors'
- '(a "Keras tensor" is a tensor that was'
- 'returned by a Keras layer, or by `Input`)')
+ ' (a "Keras tensor" is a tensor that was'
+ ' returned by a Keras layer, or by `Input`)')
if is_keras_tensor:
# Compute the full input spec, including state and constants
diff --git a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
index 5196bf1740..3bbe87f92d 100644
--- a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
@@ -490,8 +490,8 @@ def slice_arrays(arrays, start=None, stop=None):
if arrays is None:
return [None]
if isinstance(start, list) and stop is not None:
- raise ValueError('The stop argument has to be None if the value of start is'
- 'a list.')
+ raise ValueError('The stop argument has to be None if the value of start '
+ 'is a list.')
elif isinstance(arrays, list):
if hasattr(start, '__len__'):
# hdf5 datasets only support list objects as indices
diff --git a/tensorflow/python/keras/_impl/keras/utils/vis_utils.py b/tensorflow/python/keras/_impl/keras/utils/vis_utils.py
index 45c1b92075..4761cece82 100644
--- a/tensorflow/python/keras/_impl/keras/utils/vis_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/vis_utils.py
@@ -120,7 +120,7 @@ def model_to_dot(model, show_shapes=False, show_layer_names=True, rankdir='TB'):
layer_id = str(id(layer))
for i, node in enumerate(layer._inbound_nodes):
node_key = layer.name + '_ib-' + str(i)
- if node_key in model._container_nodes:
+ if node_key in model._network_nodes: # pylint: disable=protected-access
for inbound_layer in node.inbound_layers:
inbound_layer_id = str(id(inbound_layer))
layer_id = str(id(layer))
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py
index 81c6a4aa6e..c22934ce47 100644
--- a/tensorflow/python/kernel_tests/concat_op_test.py
+++ b/tensorflow/python/kernel_tests/concat_op_test.py
@@ -606,6 +606,17 @@ class ConcatOpTest(test.TestCase):
inp_tensors_placeholders, -2, output_shape=[2, 3],
gather_indexes=[2, 0], feed_dict=feed_dict)
+ def testConcatAxisType(self):
+ for dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session(use_gpu=True):
+ t1 = [[1, 2, 3], [4, 5, 6]]
+ t2 = [[7, 8, 9], [10, 11, 12]]
+
+ c = gen_array_ops.concat_v2([t1, t2],
+ constant_op.constant(1, dtype=dtype))
+ self.assertEqual([2, 6], c.get_shape().as_list())
+ output = c.eval()
+ self.assertAllEqual([[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], output)
class ConcatOffsetTest(test.TestCase):
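The new test drives `concat_v2` with both `int32` and `int64` axis tensors. The same behavior can be sketched at the public API with `tf.concat`, assuming the axis tensor is now accepted in either width:

```python
import tensorflow as tf

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
axis = tf.constant(1, dtype=tf.int64)  # int64 axis, not only int32
c = tf.concat([t1, t2], axis)
with tf.Session() as sess:
  print(sess.run(c))  # [[ 1  2  3  7  8  9] [ 4  5  6 10 11 12]]
```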
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index f4fe01f868..25525cc128 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -970,7 +970,7 @@ class Conv2DTest(test.TestCase):
self.assertArrayNear(value_2.flatten(), value.flatten(), err)
def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self):
- if test.is_gpu_available(cuda_only=True):
+ if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled():
for (data_format, use_gpu) in GetTestConfigs():
self._RunAndVerifyBackpropFilterDilation(
input_sizes=[1, 3, 6, 1],
@@ -984,7 +984,7 @@ class Conv2DTest(test.TestCase):
err=1e-5)
def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self):
- if test.is_gpu_available(cuda_only=True):
+ if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled():
for (data_format, use_gpu) in GetTestConfigs():
self._RunAndVerifyBackpropFilterDilation(
input_sizes=[1, 2, 3, 1],
@@ -998,7 +998,7 @@ class Conv2DTest(test.TestCase):
err=1e-5)
def testConv2DEmptyBackpropFilterDilation1x2(self):
- if test.is_gpu_available(cuda_only=True):
+ if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled():
for (data_format, use_gpu) in GetTestConfigs():
self._RunAndVerifyBackpropFilterDilation(
input_sizes=[1, 2, 3, 1],
@@ -1012,7 +1012,7 @@ class Conv2DTest(test.TestCase):
err=1e-5)
def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self):
- if test.is_gpu_available(cuda_only=True):
+ if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled():
for (data_format, use_gpu) in GetTestConfigs():
self._RunAndVerifyBackpropFilterDilation(
input_sizes=[1, 3, 4, 3],
@@ -1026,7 +1026,7 @@ class Conv2DTest(test.TestCase):
err=1e-5)
def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2(self):
- if test.is_gpu_available(cuda_only=True):
+ if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled():
for (data_format, use_gpu) in GetTestConfigs():
self._RunAndVerifyBackpropFilterDilation(
input_sizes=[1, 3, 3, 1],
@@ -1040,7 +1040,7 @@ class Conv2DTest(test.TestCase):
err=1e-5)
def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(self):
- if test.is_gpu_available(cuda_only=True):
+ if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled():
for (data_format, use_gpu) in GetTestConfigs():
self._RunAndVerifyBackpropInputDilation(
input_sizes=[1, 3, 6, 1],
@@ -1054,7 +1054,7 @@ class Conv2DTest(test.TestCase):
err=1e-5)
def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self):
- if test.is_gpu_available(cuda_only=True):
+ if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled():
for (data_format, use_gpu) in GetTestConfigs():
self._RunAndVerifyBackpropInputDilation(
input_sizes=[1, 2, 3, 1],
@@ -1068,7 +1068,7 @@ class Conv2DTest(test.TestCase):
err=1e-5)
def testConv2DEmptyBackpropInputDilation1x2(self):
- if test.is_gpu_available(cuda_only=True):
+ if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled():
for (data_format, use_gpu) in GetTestConfigs():
self._RunAndVerifyBackpropInputDilation(
input_sizes=[0, 2, 3, 1],
@@ -1082,7 +1082,7 @@ class Conv2DTest(test.TestCase):
err=1e-5)
def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self):
- if test.is_gpu_available(cuda_only=True):
+ if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled():
for (data_format, use_gpu) in GetTestConfigs():
# The GPU version of this test is not very stable. So adjusting the
# error threshold to 1e-4.
@@ -1098,7 +1098,7 @@ class Conv2DTest(test.TestCase):
err=1e-4)
def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2(self):
- if test.is_gpu_available(cuda_only=True):
+ if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled():
for (data_format, use_gpu) in GetTestConfigs():
self._RunAndVerifyBackpropInputDilation(
input_sizes=[1, 3, 3, 1],
diff --git a/tensorflow/python/kernel_tests/depthtospace_op_test.py b/tensorflow/python/kernel_tests/depthtospace_op_test.py
index 96c9718b83..f0beabb4e2 100644
--- a/tensorflow/python/kernel_tests/depthtospace_op_test.py
+++ b/tensorflow/python/kernel_tests/depthtospace_op_test.py
@@ -35,8 +35,8 @@ from tensorflow.python.platform import tf_logging
class DepthToSpaceTest(test.TestCase):
- def _testOne(self, inputs, block_size, outputs):
- input_nhwc = math_ops.to_float(inputs)
+ def _testOne(self, inputs, block_size, outputs, dtype=dtypes.float32):
+ input_nhwc = math_ops.cast(inputs, dtype)
with self.test_session(use_gpu=False):
# test NHWC (default) on CPU
x_tf = array_ops.depth_to_space(input_nhwc, block_size)
@@ -59,6 +59,12 @@ class DepthToSpaceTest(test.TestCase):
x_out = [[[[1], [2]], [[3], [4]]]]
self._testOne(x_np, block_size, x_out)
+ def testBasicFloat16(self):
+ x_np = [[[[1, 2, 3, 4]]]]
+ block_size = 2
+ x_out = [[[[1], [2]], [[3], [4]]]]
+ self._testOne(x_np, block_size, x_out, dtype=dtypes.float16)
+
# Tests for larger input dimensions. To make sure elements are
# correctly ordered spatially.
def testBlockSize2(self):
diff --git a/tensorflow/python/kernel_tests/spacetodepth_op_test.py b/tensorflow/python/kernel_tests/spacetodepth_op_test.py
index b76135764f..cd90d16aac 100644
--- a/tensorflow/python/kernel_tests/spacetodepth_op_test.py
+++ b/tensorflow/python/kernel_tests/spacetodepth_op_test.py
@@ -34,8 +34,8 @@ from tensorflow.python.platform import tf_logging
class SpaceToDepthTest(test.TestCase):
- def _testOne(self, inputs, block_size, outputs):
- input_nhwc = math_ops.to_float(inputs)
+ def _testOne(self, inputs, block_size, outputs, dtype=dtypes.float32):
+ input_nhwc = math_ops.cast(inputs, dtype)
with self.test_session(use_gpu=False):
# test NHWC (default) on CPU
x_tf = array_ops.space_to_depth(input_nhwc, block_size)
@@ -58,6 +58,12 @@ class SpaceToDepthTest(test.TestCase):
x_out = [[[[1, 2, 3, 4]]]]
self._testOne(x_np, block_size, x_out)
+ def testBasicFloat16(self):
+ x_np = [[[[1], [2]], [[3], [4]]]]
+ block_size = 2
+ x_out = [[[[1, 2, 3, 4]]]]
+ self._testOne(x_np, block_size, x_out, dtype=dtypes.float16)
+
# Tests for larger input dimensions. To make sure elements are
# correctly ordered spatially.
def testLargerInput2x2(self):
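Both the `depthtospace` and `spacetodepth` tests gain a `dtype` argument so the existing cases can be replayed in `float16`. The round trip at the public API, sketched under the assumption that the `float16` kernels these tests exercise are registered:

```python
import tensorflow as tf

x = tf.constant([[[[1], [2]], [[3], [4]]]], dtype=tf.float16)  # NHWC, [1, 2, 2, 1]
packed = tf.space_to_depth(x, block_size=2)         # -> [1, 1, 1, 4]
restored = tf.depth_to_space(packed, block_size=2)  # -> back to [1, 2, 2, 1]
```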
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index e9066d3fda..e4395bea92 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -578,7 +578,7 @@ class Layer(checkpointable.CheckpointableBase):
if isinstance(variable, tf_variables.PartitionedVariable):
raise RuntimeError(
'Partitioned variable regularization is not yet '
- 'supported when executing eagerly. File a feature request'
+ 'supported when executing eagerly. File a feature request '
'if this is important to you.')
# Save a zero-argument lambda which runs the regularizer on the
# variable, to be executed when `Layer.losses` is requested.
diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py
index 11daf01670..29fb92ccb5 100644
--- a/tensorflow/python/layers/normalization.py
+++ b/tensorflow/python/layers/normalization.py
@@ -664,9 +664,16 @@ def batch_normalization(inputs,
Note: when training, the moving_mean and moving_variance need to be updated.
By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
- need to be added as a dependency to the `train_op`. For example:
+ need to be added as a dependency to the `train_op`. Also, be sure to add
+ any batch_normalization ops before getting the update_ops collection.
+ Otherwise, update_ops will be empty, and training/inference will not work
+ properly. For example:
```python
+ x_norm = tf.layers.batch_normalization(x, training=training)
+
+ # ...
+
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss)
diff --git a/tensorflow/python/lib/io/file_io_test.py b/tensorflow/python/lib/io/file_io_test.py
index a751607aaa..223858edfa 100644
--- a/tensorflow/python/lib/io/file_io_test.py
+++ b/tensorflow/python/lib/io/file_io_test.py
@@ -485,6 +485,11 @@ class FileIoTest(test.TestCase):
f.flush()
self.assertEqual(content, f.read(len(content) + 1))
+ def testUTF8StringPathExists(self):
+ file_path = os.path.join(self._base_dir, "UTF8测试_file_exist")
+ file_io.write_string_to_file(file_path, "testing")
+ v = file_io.file_exists(file_path)
+ self.assertEqual(v, True)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index 48ea107a14..6fcf9c91d8 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -75,14 +75,16 @@ def tf_record_iterator(path, options=None):
if reader is None:
raise IOError("Could not open %s." % path)
- while True:
- try:
- with errors.raise_exception_on_not_ok_status() as status:
- reader.GetNext(status)
- except errors.OutOfRangeError:
- break
- yield reader.record()
- reader.Close()
+ try:
+ while True:
+ try:
+ with errors.raise_exception_on_not_ok_status() as status:
+ reader.GetNext(status)
+ except errors.OutOfRangeError:
+ break
+ yield reader.record()
+ finally:
+ reader.Close()
@tf_export("python_io.TFRecordWriter")
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index 37470e00d7..5b4fb4f7c8 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -341,7 +341,7 @@ def self_adjoint_eig(tensor, name=None):
name: string, optional name of the operation.
Returns:
- e: Eigenvalues. Shape is `[..., N]`.
+ e: Eigenvalues. Shape is `[..., N]`. Sorted in non-decreasing order.
v: Eigenvectors. Shape is `[..., N, N]`. The columns of the inner most
matrices contain eigenvectors of the corresponding matrices in `tensor`
"""
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index fb3fe77b4d..a74de39eab 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -150,14 +150,12 @@ class _NonAtrousConvolution(object):
conv_dims))
if conv_dims == 1:
# conv1d uses the 2-d data format names
- if data_format is None or data_format == "NWC":
- data_format_2d = "NHWC"
- elif data_format == "NCW":
- data_format_2d = "NCHW"
- else:
+ if data_format is None:
+ data_format = "NWC"
+ elif data_format not in {"NCW", "NWC", "NCHW", "NHWC"}:
raise ValueError("data_format must be \"NWC\" or \"NCW\".")
self.strides = strides[0]
- self.data_format = data_format_2d
+ self.data_format = data_format
self.conv_op = self._conv1d
elif conv_dims == 2:
if data_format is None or data_format == "NHWC":
@@ -699,7 +697,7 @@ def convolution(
`padded_input` is obtained by zero padding the input using an effective
spatial filter shape of `(spatial_filter_shape-1) * dilation_rate + 1` and
output striding `strides` as described in the
- @{tf.nn.convolution$comment here}.
+ @{$python/nn#Convolution$comment here}.
In the case that `data_format` does start with `"NC"`, the `input` and output
(but not the `filter`) are simply transposed as follows:
@@ -1043,9 +1041,7 @@ def pool(
@tf_export("nn.atrous_conv2d")
def atrous_conv2d(value, filters, rate, padding, name=None):
- """Atrous convolution (a.k.a.
-
- convolution with holes or dilated convolution).
+ """Atrous convolution (a.k.a. convolution with holes or dilated convolution).
This function is a simpler wrapper around the more general
@{tf.nn.convolution}, and exists only for backwards compatibility. You can
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
index db8159579a..6a2dd3f1cd 100644
--- a/tensorflow/python/ops/random_ops.py
+++ b/tensorflow/python/ops/random_ops.py
@@ -209,7 +209,7 @@ def random_uniform(shape,
maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on
the range of random values to generate. Defaults to 1 if `dtype` is
floating point.
- dtype: The type of the output: 'float16`, `float32`, `float64`, `int32`,
+ dtype: The type of the output: `float16`, `float32`, `float64`, `int32`,
or `int64`.
seed: A Python integer. Used to create a random seed for the distribution.
See @{tf.set_random_seed}
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index c59eccc174..42af7f8b27 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -867,7 +867,7 @@ def raw_rnn(cell, loop_fn,
```python
time = tf.constant(0, dtype=tf.int32)
- (finished, next_input, initial_state, _, loop_state) = loop_fn(
+ (finished, next_input, initial_state, emit_structure, loop_state) = loop_fn(
time=time, cell_output=None, cell_state=None, loop_state=None)
emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype)
state = initial_state
@@ -878,7 +878,7 @@ def raw_rnn(cell, loop_fn,
loop_state=loop_state)
# Emit zeros and copy forward state for minibatch entries that are finished.
state = tf.where(finished, state, next_state)
- emit = tf.where(finished, tf.zeros_like(emit), emit)
+ emit = tf.where(finished, tf.zeros_like(emit_structure), emit)
emit_ta = emit_ta.write(time, emit)
# If any new minibatch entries are marked as finished, mark these.
finished = tf.logical_or(finished, next_finished)
@@ -938,10 +938,15 @@ def raw_rnn(cell, loop_fn,
and `emit_output`: the output to store for this iteration.
Note that `emit_output` should be a `Tensor` or (possibly nested)
- tuple of tensors with shapes and structure matching `cell.output_size`
- and `cell_output` above. The parameter `cell_state` and output
- `next_cell_state` may be either a single or (possibly nested) tuple
- of tensors. The parameter `loop_state` and
+ tuple of tensors which is aggregated in the `emit_ta` inside the
+ `while_loop`. For the first call to `loop_fn`, the `emit_output`
+ corresponds to the `emit_structure` which is then used to determine the
+ size of the `zero_tensor` for the `emit_ta` (defaults to
+ `cell.output_size`). For the subsequent calls to the `loop_fn`, the
+ `emit_output` corresponds to the actual output tensor
+ that is to be aggregated in the `emit_ta`. The parameter `cell_state`
+ and output `next_cell_state` may be either a single or (possibly nested)
+ tuple of tensors. The parameter `loop_state` and
output `next_loop_state` may be either a single or (possibly nested) tuple
of `Tensor` and `TensorArray` objects. This last parameter
may be ignored by `loop_fn` and the return value may be `None`. If it
diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py
index 6d7eaababc..5e2146b79f 100644
--- a/tensorflow/python/ops/special_math_ops.py
+++ b/tensorflow/python/ops/special_math_ops.py
@@ -163,7 +163,7 @@ def einsum(equation, *inputs, **kwargs):
if '...' in equation:
raise ValueError('Subscripts with ellipses are not yet supported.')
- match = re.match('([a-z,]+)(->[a-z]*)?', equation)
+ match = re.match('^([a-zA-Z,]+)(->[a-zA-Z]*)?$', equation)
if not match:
raise ValueError('Indices have incorrect format: %s' % equation)
@@ -402,7 +402,7 @@ def _exponential_space_einsum(equation, *inputs):
if '...' in equation:
raise ValueError('Subscripts with ellipses are not yet supported.')
- match = re.match('([a-z,]+)(->[a-z]*)?', equation)
+ match = re.match('^([a-zA-Z,]+)(->[a-zA-Z]*)?$', equation)
if not match:
raise ValueError('Indices have incorrect format: %s' % equation)
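The regex change, applied identically in both einsum implementations, does two things: it admits uppercase subscript letters, and the `^...$` anchoring rejects equations with stray characters (such as the embedded spaces in the new invalid test cases below) that the unanchored pattern would previously have truncated silently. A sketch of a now-valid mixed-case equation:

```python
import tensorflow as tf

a = tf.ones([2, 3])
b = tf.ones([3, 4])
c = tf.einsum('iJ,Jk->ik', a, b)  # uppercase subscripts now parse; shape [2, 4]
```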
diff --git a/tensorflow/python/ops/special_math_ops_test.py b/tensorflow/python/ops/special_math_ops_test.py
index 2c212f4548..d7c3a7e8dc 100644
--- a/tensorflow/python/ops/special_math_ops_test.py
+++ b/tensorflow/python/ops/special_math_ops_test.py
@@ -192,6 +192,9 @@ class EinsumTest(test.TestCase):
'abc,cba',
'dba,ead,cad->bce',
'aef,fbc,dca->bde',
+ 'iJ,Jk->ik',
+ 'iJ,Ki->JK',
+ 'iJk,Jklm->Jk'
]
long_cases = [
@@ -208,6 +211,8 @@ class EinsumTest(test.TestCase):
'ijk ijk',
'ij.jk->ik',
'ij...,jk...->ik...',
+ 'ij,k ->kji',
+ 'ij,k-> kji',
# axis in output that does not exist
'ij,jk->im',
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index a52f325ddb..e9f1def48c 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -56,8 +56,6 @@ from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as saver_lib
-FLAGS = None
-
def freeze_graph_with_def_protos(input_graph_def,
input_saver_def,
@@ -256,25 +254,24 @@ def freeze_graph(input_graph,
checkpoint_version=checkpoint_version)
-def main(unused_args):
- if FLAGS.checkpoint_version == 1:
+def main(unused_args, flags):
+ if flags.checkpoint_version == 1:
checkpoint_version = saver_pb2.SaverDef.V1
- elif FLAGS.checkpoint_version == 2:
+ elif flags.checkpoint_version == 2:
checkpoint_version = saver_pb2.SaverDef.V2
else:
print("Invalid checkpoint version (must be '1' or '2'): %d" %
- FLAGS.checkpoint_version)
+ flags.checkpoint_version)
return -1
- freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary,
- FLAGS.input_checkpoint, FLAGS.output_node_names,
- FLAGS.restore_op_name, FLAGS.filename_tensor_name,
- FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes,
- FLAGS.variable_names_whitelist, FLAGS.variable_names_blacklist,
- FLAGS.input_meta_graph, FLAGS.input_saved_model_dir,
- FLAGS.saved_model_tags, checkpoint_version)
-
+ freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary,
+ flags.input_checkpoint, flags.output_node_names,
+ flags.restore_op_name, flags.filename_tensor_name,
+ flags.output_graph, flags.clear_devices, flags.initializer_nodes,
+ flags.variable_names_whitelist, flags.variable_names_blacklist,
+ flags.input_meta_graph, flags.input_saved_model_dir,
+ flags.saved_model_tags, checkpoint_version)
-if __name__ == "__main__":
+def run_main():
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
@@ -376,5 +373,10 @@ if __name__ == "__main__":
separated by \',\'. For tag-set contains multiple tags, all tags \
must be passed in.\
""")
- FLAGS, unparsed = parser.parse_known_args()
- app.run(main=main, argv=[sys.argv[0]] + unparsed)
+ flags, unparsed = parser.parse_known_args()
+
+ my_main = lambda unused_args: main(unused_args, flags)
+ app.run(main=my_main, argv=[sys.argv[0]] + unparsed)
+
+if __name__ == '__main__':
+ run_main()
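Removing the module-level `FLAGS` global and threading the parsed flags through `main` makes the module importable without argument-parsing side effects, while `run_main()` preserves the command-line behavior. A hypothetical wrapper script:

```python
#!/usr/bin/env python
# Hypothetical entry-point script; run_main() parses sys.argv itself.
from tensorflow.python.tools import freeze_graph

if __name__ == '__main__':
  freeze_graph.run_main()
```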
diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py
index dd876cbe7f..6504fbc107 100644
--- a/tensorflow/python/tools/inspect_checkpoint.py
+++ b/tensorflow/python/tools/inspect_checkpoint.py
@@ -30,7 +30,7 @@ FLAGS = None
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors,
- all_tensor_names):
+ all_tensor_names=False):
"""Prints tensors in a checkpoint file.
If no `tensor_name` is provided, prints the tensor names and shapes
@@ -139,7 +139,7 @@ if __name__ == "__main__":
const=True,
type="bool",
default=False,
- help="If True, print the values of all the tensors.")
+ help="If True, print the names and values of all the tensors.")
parser.add_argument(
"--all_tensor_names",
nargs="?",
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index b0e9e3e5ed..b88be4ae04 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -38,11 +38,15 @@ from tensorflow.core.example import example_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.client import session
from tensorflow.python.debug.wrappers import local_cli_wrapper
+from tensorflow.python.framework import meta_graph as meta_graph_lib
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.platform import app # pylint: disable=unused-import
from tensorflow.python.saved_model import loader
from tensorflow.python.tools import saved_model_utils
+# Set of ops to blacklist.
+_OP_BLACKLIST = set(['WriteFile', 'ReadFile'])
+
def _show_tag_sets(saved_model_dir):
"""Prints the tag-sets stored in SavedModel directory.
@@ -242,6 +246,27 @@ def get_signature_def_map(saved_model_dir, tag_set):
return meta_graph.signature_def
+def scan_meta_graph_def(meta_graph_def):
+ """Scans meta_graph_def and reports if there are ops on blacklist.
+
+ Print ops if they are on black list, or print success if no blacklisted ops
+ found.
+
+ Args:
+ meta_graph_def: MetaGraphDef protocol buffer.
+ """
+ all_ops_set = set(
+ meta_graph_lib.ops_used_by_graph_def(meta_graph_def.graph_def))
+ blacklisted_ops = _OP_BLACKLIST & all_ops_set
+ if blacklisted_ops:
+ # TODO(yifeif): print more warnings
+ print('MetaGraph with tag set %s contains the following blacklisted ops:' %
+ meta_graph_def.meta_info_def.tags, blacklisted_ops)
+ else:
+ print('MetaGraph with tag set %s does not contain blacklisted ops.' %
+ meta_graph_def.meta_info_def.tags)
+
+
def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
input_tensor_key_feed_dict, outdir,
overwrite_flag, tf_debug=False):
@@ -609,6 +634,21 @@ def run(args):
args.overwrite, tf_debug=args.tf_debug)
+def scan(args):
+ """Function triggered by scan command.
+
+ Args:
+ args: A namespace parsed from command line.
+ """
+ if args.tag_set:
+ scan_meta_graph_def(
+ saved_model_utils.get_meta_graph_def(args.dir, args.tag_set))
+ else:
+ saved_model = reader.read_saved_model(args.dir)
+ for meta_graph_def in saved_model.meta_graphs:
+ scan_meta_graph_def(meta_graph_def)
+
+
def create_parser():
"""Creates a parser that parse the command line arguments.
@@ -730,6 +770,26 @@ def create_parser():
'SavedModel.')
parser_run.set_defaults(func=run)
+ # scan command
+ scan_msg = ('Usage example:\n'
+ 'To scan for blacklisted ops in SavedModel:\n'
+ '$saved_model_cli scan --dir /tmp/saved_model\n'
+ 'To scan a specific MetaGraph, pass in --tag_set\n')
+ parser_scan = subparsers.add_parser(
+ 'scan',
+ description=scan_msg,
+ formatter_class=argparse.RawTextHelpFormatter)
+ parser_scan.add_argument(
+ '--dir',
+ type=str,
+ required=True,
+ help='directory containing the SavedModel to execute')
+ parser_scan.add_argument(
+ '--tag_set',
+ type=str,
+ help='tag-set of graph in SavedModel to scan, separated by \',\'')
+ parser_scan.set_defaults(func=scan)
+
return parser
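Besides the shell form shown in the usage message, the new subcommand can be driven through the parser, since `set_defaults(func=scan)` wires the handler. A sketch (the SavedModel path is a placeholder):

```python
from tensorflow.python.tools import saved_model_cli

parser = saved_model_cli.create_parser()
args = parser.parse_args(['scan', '--dir', '/tmp/saved_model'])
args.func(args)  # prints whether any blacklisted ops were found
```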
diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py
index f99c844845..eedc893a38 100644
--- a/tensorflow/python/tools/saved_model_cli_test.py
+++ b/tensorflow/python/tools/saved_model_cli_test.py
@@ -525,6 +525,28 @@ signature_def['serving_default']:
y_expected = np.array([[2.5], [3.0]])
self.assertAllClose(y_expected, y_actual)
+ def testScanCommand(self):
+ self.parser = saved_model_cli.create_parser()
+ base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
+ args = self.parser.parse_args(['scan', '--dir', base_path])
+ with captured_output() as (out, _):
+ saved_model_cli.scan(args)
+ output = out.getvalue().strip()
+ self.assertTrue('does not contain blacklisted ops' in output)
+
+ def testScanCommandFoundBlacklistedOp(self):
+ self.parser = saved_model_cli.create_parser()
+ base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
+ args = self.parser.parse_args(
+ ['scan', '--dir', base_path, '--tag_set', 'serve'])
+ op_blacklist = saved_model_cli._OP_BLACKLIST
+ saved_model_cli._OP_BLACKLIST = set(['VariableV2'])
+ with captured_output() as (out, _):
+ saved_model_cli.scan(args)
+ saved_model_cli._OP_BLACKLIST = op_blacklist
+ output = out.getvalue().strip()
+ self.assertTrue('\'VariableV2\'' in output)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 5ef8bd9e9c..ba0d038475 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -1135,8 +1135,9 @@ class Saver(object):
the proliferation of checkpoint files on disk:
* `max_to_keep` indicates the maximum number of recent checkpoint files to
- keep. As new files are created, older files are deleted. If None or 0,
- all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent
+ keep. As new files are created, older files are deleted. If None or 0,
+ no checkpoints are deleted from the filesystem but only the last one is
+ kept in the `checkpoint` file. Defaults to 5 (that is, the 5 most recent
checkpoint files are kept.)
* `keep_checkpoint_every_n_hours`: In addition to keeping the most recent