Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/BUILD | 36
-rw-r--r--  tensorflow/python/__init__.py | 2
-rw-r--r--  tensorflow/python/client/session_benchmark.py | 1
-rw-r--r--  tensorflow/python/data/ops/dataset_ops.py | 49
-rw-r--r--  tensorflow/python/data/util/nest.py | 6
-rw-r--r--  tensorflow/python/data/util/sparse.py | 2
-rw-r--r--  tensorflow/python/debug/cli/tensor_format.py | 2
-rw-r--r--  tensorflow/python/debug/lib/debug_data.py | 2
-rw-r--r--  tensorflow/python/eager/execution_callbacks.py | 2
-rw-r--r--  tensorflow/python/estimator/canned/dnn_testing_utils.py | 2
-rw-r--r--  tensorflow/python/estimator/canned/linear_testing_utils.py | 2
-rw-r--r--  tensorflow/python/estimator/estimator.py | 3
-rw-r--r--  tensorflow/python/estimator/run_config.py | 5
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/convolutional.py | 2
-rw-r--r--  tensorflow/python/kernel_tests/BUILD | 13
-rw-r--r--  tensorflow/python/kernel_tests/array_ops_test.py | 21
-rw-r--r--  tensorflow/python/kernel_tests/constant_op_test.py | 13
-rw-r--r--  tensorflow/python/kernel_tests/conv_ops_test.py | 3
-rw-r--r--  tensorflow/python/kernel_tests/decode_jpeg_op_test.py | 1
-rw-r--r--  tensorflow/python/kernel_tests/io_ops_test.py | 2
-rw-r--r--  tensorflow/python/kernel_tests/losses_test.py | 16
-rw-r--r--  tensorflow/python/kernel_tests/manip_ops_test.py | 138
-rw-r--r--  tensorflow/python/kernel_tests/rnn_test.py | 1
-rw-r--r--  tensorflow/python/kernel_tests/tensordot_op_test.py | 54
-rw-r--r--  tensorflow/python/kernel_tests/topk_op_test.py | 2
-rw-r--r--  tensorflow/python/layers/convolutional.py | 11
-rw-r--r--  tensorflow/python/layers/utils.py | 2
-rw-r--r--  tensorflow/python/ops/array_ops.py | 7
-rw-r--r--  tensorflow/python/ops/functional_ops.py | 2
-rw-r--r--  tensorflow/python/ops/gradients_impl.py | 1
-rw-r--r--  tensorflow/python/ops/image_ops.py | 4
-rw-r--r--  tensorflow/python/ops/image_ops_impl.py | 106
-rw-r--r--  tensorflow/python/ops/image_ops_test.py | 112
-rw-r--r--  tensorflow/python/ops/linalg_grad.py | 59
-rw-r--r--  tensorflow/python/ops/losses/losses_impl.py | 7
-rw-r--r--  tensorflow/python/ops/manip_grad.py | 31
-rw-r--r--  tensorflow/python/ops/manip_ops.py | 38
-rw-r--r--  tensorflow/python/ops/math_ops.py | 10
-rw-r--r--  tensorflow/python/ops/rnn.py | 6
-rw-r--r--  tensorflow/python/ops/standard_ops.py | 73
-rw-r--r--  tensorflow/python/saved_model/loader_impl.py | 9
-rw-r--r--  tensorflow/python/tools/freeze_graph.py | 38
-rw-r--r--  tensorflow/python/tools/freeze_graph_test.py | 16
-rw-r--r--  tensorflow/python/tools/optimize_for_inference_lib.py | 1
-rw-r--r--  tensorflow/python/tools/optimize_for_inference_test.py | 92
-rw-r--r--  tensorflow/python/tools/saved_model_cli.py | 3
-rw-r--r--  tensorflow/python/training/basic_session_run_hooks.py | 4
-rw-r--r--  tensorflow/python/training/input.py | 2
-rw-r--r--  tensorflow/python/training/saver.py | 12
-rw-r--r--  tensorflow/python/util/compat_internal.py | 34
50 files changed, 832 insertions(+), 228 deletions(-)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 2b4d5b8e0f..f5cd7885e7 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -76,6 +76,7 @@ py_library(
":layers",
":lib",
":list_ops",
+ ":manip_ops",
":math_ops",
":metrics",
":nn",
@@ -1424,6 +1425,14 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "manip_ops_gen",
+ visibility = [
+ "//learning/brain/python/ops:__pkg__",
+ "//tensorflow/python/kernel_tests:__pkg__",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "math_ops_gen",
visibility = [
"//learning/brain/google/python/ops:__pkg__",
@@ -1755,6 +1764,8 @@ py_library(
":linalg_grad",
":linalg_ops",
":logging_ops",
+ ":manip_grad",
+ ":manip_ops",
":math_grad",
":math_ops",
":platform",
@@ -1878,6 +1889,29 @@ py_library(
)
py_library(
+ name = "manip_grad",
+ srcs = ["ops/manip_grad.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":control_flow_ops",
+ ":framework_for_generated_wrappers",
+ ":manip_ops",
+ ],
+)
+
+py_library(
+ name = "manip_ops",
+ srcs = ["ops/manip_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dtypes",
+ ":framework_ops",
+ ":manip_ops_gen",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
name = "logging_ops",
srcs = ["ops/logging_ops.py"],
srcs_version = "PY2AND3",
@@ -2339,6 +2373,8 @@ py_library(
":linalg_ops",
":logging_ops",
":lookup_ops",
+ ":manip_grad",
+ ":manip_ops",
":math_grad",
":math_ops",
":numerics",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index bc9ddec2a5..ea7604d30f 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -84,6 +84,7 @@ from tensorflow.python.feature_column import feature_column_lib as feature_colum
from tensorflow.python.layers import layers
from tensorflow.python.ops import bitwise_ops as bitwise
from tensorflow.python.ops import image_ops as image
+from tensorflow.python.ops import manip_ops as manip
from tensorflow.python.ops import metrics
from tensorflow.python.ops import nn
from tensorflow.python.ops import sets
@@ -241,6 +242,7 @@ _allowed_symbols.extend([
'linalg',
'logging',
'losses',
+ 'manip',
'metrics',
'newaxis',
'nn',
diff --git a/tensorflow/python/client/session_benchmark.py b/tensorflow/python/client/session_benchmark.py
index 721bca91b7..da74855193 100644
--- a/tensorflow/python/client/session_benchmark.py
+++ b/tensorflow/python/client/session_benchmark.py
@@ -22,6 +22,7 @@ import time
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 5d318531d5..c4b7e4919b 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -203,7 +203,7 @@ class Dataset(object):
tensors: A nested structure of tensors.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return TensorDataset(tensors)
@@ -216,7 +216,7 @@ class Dataset(object):
0th dimension.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return TensorSliceDataset(tensors)
@@ -229,7 +229,7 @@ class Dataset(object):
sparse_tensor: A `tf.SparseTensor`.
Returns:
- A `Dataset` of rank-(N-1) sparse tensors.
+ Dataset: A `Dataset` of rank-(N-1) sparse tensors.
"""
return SparseTensorSliceDataset(sparse_tensor)
@@ -315,7 +315,7 @@ class Dataset(object):
`generator`.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
if not callable(generator):
raise TypeError("`generator` must be callable.")
@@ -458,7 +458,7 @@ class Dataset(object):
len(args) == 3 -> start = args[0], stop = args[1, stop = args[2]
Returns:
- A `RangeDataset`.
+ Dataset: A `RangeDataset`.
Raises:
ValueError: if len(args) == 0.
@@ -502,7 +502,7 @@ class Dataset(object):
datasets: A nested structure of datasets.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return ZipDataset(datasets)
@@ -528,7 +528,7 @@ class Dataset(object):
dataset: `Dataset` to be concatenated.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return ConcatenateDataset(self, dataset)
@@ -540,7 +540,7 @@ class Dataset(object):
maximum number elements that will be buffered when prefetching.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return PrefetchDataset(self, buffer_size)
@@ -558,12 +558,14 @@ class Dataset(object):
- /path/to/dir/b.py
- /path/to/dir/c.py
+ NOTE: The order of the file names returned can be non-deterministic.
+
Args:
file_pattern: A string or scalar string `tf.Tensor`, representing
the filename pattern that will be matched.
Returns:
- A `Dataset` of strings corresponding to file names.
+ Dataset: A `Dataset` of strings corresponding to file names.
"""
return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern))
@@ -580,7 +582,7 @@ class Dataset(object):
indefinitely.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return RepeatDataset(self, count)
@@ -604,7 +606,7 @@ class Dataset(object):
iterated over. (Defaults to `True`.)
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration)
@@ -617,7 +619,7 @@ class Dataset(object):
If a filename is not provided, the dataset will be cached in memory.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return CacheDataset(self, filename)
@@ -631,7 +633,7 @@ class Dataset(object):
dataset, the new dataset will contain all elements of this dataset.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return TakeDataset(self, count)
@@ -646,7 +648,7 @@ class Dataset(object):
is -1, skips the entire dataset.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return SkipDataset(self, count)
@@ -693,7 +695,7 @@ class Dataset(object):
index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
Raises:
ValueError: if `num_shards` or `index` are illegal values. Note: error
@@ -737,7 +739,7 @@ class Dataset(object):
consecutive elements of this dataset to combine in a single batch.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return BatchDataset(self, batch_size)
@@ -766,7 +768,7 @@ class Dataset(object):
the empty string for string types.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values)
@@ -782,7 +784,7 @@ class Dataset(object):
specified, elements will be processed sequentially.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
if num_parallel_calls is None:
return MapDataset(self, map_func)
@@ -798,7 +800,7 @@ class Dataset(object):
`Dataset`.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return FlatMapDataset(self, map_func)
@@ -867,7 +869,7 @@ class Dataset(object):
input element before cycling to another input element.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return InterleaveDataset(self, map_func, cycle_length, block_length)
@@ -880,7 +882,7 @@ class Dataset(object):
scalar `tf.bool` tensor.
Returns:
- A `Dataset`.
+ Dataset: A `Dataset`.
"""
return FilterDataset(self, predicate)
@@ -901,10 +903,11 @@ class Dataset(object):
Args:
transformation_func: A function that takes one `Dataset` argument and
- returns a `Dataset`.
+ returns a `Dataset`.
Returns:
- The `Dataset` returned by applying `transformation_func` to this dataset.
+ Dataset: The `Dataset` returned by applying `transformation_func` to this
+ dataset.
"""
dataset = transformation_func(self)
if not isinstance(dataset, Dataset):
diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py
index e387e35740..e90ce3fb40 100644
--- a/tensorflow/python/data/util/nest.py
+++ b/tensorflow/python/data/util/nest.py
@@ -266,7 +266,7 @@ def map_structure(func, *structure, **check_types_dict):
and the return value will contain the results in the same structure.
Args:
- func: A callable that acceps as many arguments are there are structures.
+ func: A callable that accepts as many arguments are there are structures.
*structure: scalar, or tuple or list of constructed scalars and/or other
tuples/lists, or scalars. Note: numpy arrays are considered scalars.
**check_types_dict: only valid keyword argument is `check_types`. If set to
@@ -479,8 +479,8 @@ def map_structure_up_to(shallow_tree, func, *inputs):
The `inputs`, can be thought of as having the same structure as
`shallow_tree`, but with leaf nodes that are themselves tree structures.
- This function therefore will return something with the same base structure as
- `shallow_tree`.
+ This function, therefore, will return something with the same base structure
+ as `shallow_tree`.
Examples:
diff --git a/tensorflow/python/data/util/sparse.py b/tensorflow/python/data/util/sparse.py
index 5ebcb4ea81..5e6d224709 100644
--- a/tensorflow/python/data/util/sparse.py
+++ b/tensorflow/python/data/util/sparse.py
@@ -141,7 +141,7 @@ def serialize_sparse_tensors(tensors):
tensors: a tensor structure to serialize.
Returns:
- `tensors` with any sparse tensors replaced by the their serialized version.
+ `tensors` with any sparse tensors replaced by their serialized version.
"""
ret = nest.pack_sequence_as(tensors, [
diff --git a/tensorflow/python/debug/cli/tensor_format.py b/tensorflow/python/debug/cli/tensor_format.py
index d4aea76d65..e0759a8bc1 100644
--- a/tensorflow/python/debug/cli/tensor_format.py
+++ b/tensorflow/python/debug/cli/tensor_format.py
@@ -535,7 +535,7 @@ def numeric_summary(tensor):
if not isinstance(tensor, np.ndarray) or not np.size(tensor):
return debugger_cli_common.RichTextLines([
"No numeric summary available due to empty tensor."])
- elif (np.issubdtype(tensor.dtype, np.float) or
+ elif (np.issubdtype(tensor.dtype, np.floating) or
np.issubdtype(tensor.dtype, np.complex) or
np.issubdtype(tensor.dtype, np.integer)):
counts = [
diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py
index c4b13a1045..8d355aa27f 100644
--- a/tensorflow/python/debug/lib/debug_data.py
+++ b/tensorflow/python/debug/lib/debug_data.py
@@ -222,7 +222,7 @@ def has_inf_or_nan(datum, tensor):
# Also return False for data types that cannot be represented as numpy
# arrays.
return False
- elif (np.issubdtype(tensor.dtype, np.float) or
+ elif (np.issubdtype(tensor.dtype, np.floating) or
np.issubdtype(tensor.dtype, np.complex) or
np.issubdtype(tensor.dtype, np.integer)):
return np.any(np.isnan(tensor)) or np.any(np.isinf(tensor))
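Note on the np.float -> np.floating changes above (and the identical ones in execution_callbacks.py and topk_op_test.py below): np.issubdtype used to silently translate the builtin float into the abstract np.floating, but NumPy 1.14 deprecated that conversion, and under the strict semantics float means exactly np.float64, so float16/float32 tensors would stop matching. A minimal illustration in plain NumPy:

    import numpy as np

    x = np.zeros(3, dtype=np.float32)

    # The abstract base class matches every float width (16/32/64).
    assert np.issubdtype(x.dtype, np.floating)

    # Under the strict post-deprecation semantics, the builtin `float`
    # means exactly np.float64, so this float32 array would not match:
    #     np.issubdtype(x.dtype, float)   ->   False on modern NumPy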
diff --git a/tensorflow/python/eager/execution_callbacks.py b/tensorflow/python/eager/execution_callbacks.py
index 2f1654dda4..988442c971 100644
--- a/tensorflow/python/eager/execution_callbacks.py
+++ b/tensorflow/python/eager/execution_callbacks.py
@@ -153,7 +153,7 @@ def inf_nan_callback(op_type,
continue
numpy_dtype = output.dtype.as_numpy_dtype
- if (np.issubdtype(numpy_dtype, np.float) or
+ if (np.issubdtype(numpy_dtype, np.floating) or
np.issubdtype(numpy_dtype, np.complex) or
np.issubdtype(numpy_dtype, np.integer)):
try:
diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py
index 2bdec69303..706575985f 100644
--- a/tensorflow/python/estimator/canned/dnn_testing_utils.py
+++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py
@@ -877,7 +877,7 @@ class BaseDNNWarmStartingTest(object):
# Create a second DNNClassifier, warm-started from the first. Use a
# learning_rate = 0.0 optimizer to check values (use SGD so we don't have
- # accumulator values that change). Use a a new FeatureColumn with a
+ # accumulator values that change). Use a new FeatureColumn with a
# different vocabulary for occupation.
new_vocab_list = ['doctor', 'consultant', 'engineer']
new_vocab_file = os.path.join(self._ckpt_and_vocab_dir,
diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py
index cccb9af4b2..3e9183cf1b 100644
--- a/tensorflow/python/estimator/canned/linear_testing_utils.py
+++ b/tensorflow/python/estimator/canned/linear_testing_utils.py
@@ -2003,7 +2003,7 @@ class BaseLinearWarmStartingTest(object):
# Create a second LinearClassifier, warm-started from the first. Use a
# learning_rate = 0.0 optimizer to check values (use SGD so we don't have
- # accumulator values that change). Use a a new FeatureColumn with a
+ # accumulator values that change). Use a new FeatureColumn with a
# different vocabulary for occupation.
new_vocab_list = ['doctor', 'consultant', 'engineer']
new_vocab_file = os.path.join(self._ckpt_and_vocab_dir,
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 6da890cd22..17fab3df4d 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -55,6 +55,7 @@ from tensorflow.python.training import saver
from tensorflow.python.training import training
from tensorflow.python.training import training_util
from tensorflow.python.util import compat
+from tensorflow.python.util import compat_internal
from tensorflow.python.util import nest
@@ -179,7 +180,7 @@ class Estimator(object):
self._config = config
# Model directory.
- model_dir = compat.path_to_str(model_dir)
+ model_dir = compat_internal.path_to_str(model_dir)
if (model_dir is not None) and (self._config.model_dir is not None):
if model_dir != self._config.model_dir:
# TODO(alanyee): remove this suppression after it is no longer needed
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index 61a537022b..0c636a8da1 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -27,7 +27,7 @@ import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
-from tensorflow.python.util import compat
+from tensorflow.python.util import compat_internal
_USE_DEFAULT = object()
@@ -444,7 +444,8 @@ class RunConfig(object):
if tf_config:
logging.info('TF_CONFIG environment variable: %s', tf_config)
- model_dir = _get_model_dir(tf_config, compat.path_to_str(model_dir))
+ model_dir = _get_model_dir(tf_config,
+ compat_internal.path_to_str(model_dir))
RunConfig._replace(
self,
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional.py b/tensorflow/python/keras/_impl/keras/layers/convolutional.py
index b2ad4c4b65..2ee0732775 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional.py
@@ -563,7 +563,7 @@ class Conv2DTranspose(tf_convolutional_layers.Conv2DTranspose, Layer):
return dict(list(base_config.items()) + list(config.items()))
-class Conv3DTranspose(tf_convolutional_layers.Conv3D, Layer):
+class Conv3DTranspose(tf_convolutional_layers.Conv3DTranspose, Layer):
"""Transposed convolution layer (sometimes called Deconvolution).
The need for transposed convolutions generally arises
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index c87b7652ad..3a6058054b 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1602,6 +1602,19 @@ cuda_py_test(
)
cuda_py_test(
+ name = "manip_ops_test",
+ size = "small",
+ srcs = ["manip_ops_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:manip_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ ],
+ tags = ["no_windows_gpu"],
+)
+
+cuda_py_test(
name = "matmul_op_test",
size = "small",
srcs = ["matmul_op_test.py"],
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index aae6d0a36e..7ec4624310 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -1162,6 +1162,27 @@ class InvertPermutationTest(test_util.TensorFlowTestCase):
self.assertAllEqual(y.eval(), [2, 4, 3, 0, 1])
+class UnravelIndexTest(test_util.TensorFlowTestCase):
+
+ def testUnravelIndex(self):
+ with self.test_session():
+ for dtype in [dtypes.int32, dtypes.int64]:
+ indices_1 = constant_op.constant(1621, dtype=dtype)
+ dims_1 = constant_op.constant([6, 7, 8, 9], dtype=dtype)
+ out_1 = array_ops.unravel_index(indices_1, dims_1)
+ self.assertAllEqual(out_1.eval(), [3, 1, 4, 1])
+
+ indices_2 = constant_op.constant([1621], dtype=dtype)
+ dims_2 = constant_op.constant([6, 7, 8, 9], dtype=dtype)
+ out_2 = array_ops.unravel_index(indices_2, dims_2)
+ self.assertAllEqual(out_2.eval(), [[3], [1], [4], [1]])
+
+ indices_3 = constant_op.constant([22, 41, 37], dtype=dtype)
+ dims_3 = constant_op.constant([7, 6], dtype=dtype)
+ out_3 = array_ops.unravel_index(indices_3, dims_3)
+ self.assertAllEqual(out_3.eval(), [[3, 6, 6], [4, 5, 1]])
+
+
class GuaranteeConstOpTest(test_util.TensorFlowTestCase):
def testSimple(self):
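The new UnravelIndexTest exercises tf.unravel_index, which mirrors np.unravel_index: it turns flat indices into coordinates for a given shape, and a 1-D batch of indices comes back as one row per dimension. A quick NumPy cross-check of the first test case:

    import numpy as np

    # 1621 = 3*(7*8*9) + 1*(8*9) + 4*9 + 1, hence coordinates (3, 1, 4, 1).
    assert np.unravel_index(1621, (6, 7, 8, 9)) == (3, 1, 4, 1)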
diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py
index 030c690167..16e56349c4 100644
--- a/tensorflow/python/kernel_tests/constant_op_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_test.py
@@ -454,18 +454,19 @@ class ZerosLikeTest(test.TestCase):
def testZerosLikeCPU(self):
for dtype in [
- dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int8,
- dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.uint16, dtypes_lib.int32,
- dtypes_lib.int64, dtypes_lib.bool, dtypes_lib.complex64,
- dtypes_lib.complex128, dtypes_lib.string
+ dtypes_lib.half, dtypes_lib.float32, dtypes_lib.float64,
+ dtypes_lib.int8, dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.uint16,
+ dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.bool,
+ dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.string
]:
self._compareZeros(dtype, fully_defined_shape=False, use_gpu=False)
self._compareZeros(dtype, fully_defined_shape=True, use_gpu=False)
def testZerosLikeGPU(self):
for dtype in [
- dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32,
- dtypes_lib.bool, dtypes_lib.int64, dtypes_lib.string
+ dtypes_lib.half, dtypes_lib.float32, dtypes_lib.float64,
+ dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.complex64,
+ dtypes_lib.complex128, dtypes_lib.bool
]:
self._compareZeros(dtype, fully_defined_shape=False, use_gpu=True)
self._compareZeros(dtype, fully_defined_shape=True, use_gpu=True)
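The widened dtype lists track zeros_like gaining half-precision support (and complex support in the GPU kernel). A short sketch of the API under test, in the TF 1.x style used throughout this diff:

    import tensorflow as tf

    x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    # Same shape as `x`, different element type.
    z = tf.zeros_like(x, dtype=tf.half)

    with tf.Session() as sess:
        print(sess.run(z))   # [[0. 0.] [0. 0.]], dtype float16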
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 3e9bd3dade..edfb20d6a2 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -24,6 +24,7 @@ import time
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib import layers
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
@@ -519,7 +520,7 @@ class Conv2DTest(test.TestCase):
dilations=[2, 2],
padding="VALID")
- # TODO this currently fails.
+ # TODO(yzhwang): this currently fails.
# self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1],
# filter_in_sizes=[2, 2, 1, 1],
# strides=[4, 4], padding="SAME",
diff --git a/tensorflow/python/kernel_tests/decode_jpeg_op_test.py b/tensorflow/python/kernel_tests/decode_jpeg_op_test.py
index ead55cd03b..89fd26c544 100644
--- a/tensorflow/python/kernel_tests/decode_jpeg_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_jpeg_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import os
import time
+from six.moves import xrange
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
diff --git a/tensorflow/python/kernel_tests/io_ops_test.py b/tensorflow/python/kernel_tests/io_ops_test.py
index f91875c6f0..61944f7e31 100644
--- a/tensorflow/python/kernel_tests/io_ops_test.py
+++ b/tensorflow/python/kernel_tests/io_ops_test.py
@@ -1,4 +1,4 @@
-# -*- coding: utf-8 -*-
+# -*- coding: utf-8 -*-
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index 00c6706593..197dbf44af 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -953,14 +953,14 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
# Compute the expected loss 'manually'.
total = np.zeros((batch_size,))
for b in range(batch_size):
- for i in range(dims):
- for j in range(dims):
+ for i in range(dims - 1):
+ for j in range(i + 1, dims):
x = self._predictions[b, i].item() - self._predictions[b, j].item()
y = self._labels[b, i].item() - self._labels[b, j].item()
diff = (x - y)
total[b] += (diff * diff)
- self._expected_losses = np.divide(total, 9.0)
+ self._expected_losses = np.divide(total, 3.0)
def testValueErrorThrownWhenWeightIsNone(self):
with self.test_session():
@@ -1059,8 +1059,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
[[4, 8, 12], [1, 2, 3], [4, 5, 6]],
[[8, 1, 3], [7, 8, 9], [10, 11, 12]],
])
- self._test_valid_weights(
- labels, predictions, expected_loss=122.22222)
+ self._test_valid_weights(labels, predictions, expected_loss=137.5)
def test3dWeightedScalar(self):
labels = np.array([
@@ -1073,8 +1072,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
])
weight = 3.0
self._test_valid_weights(
- labels, predictions, expected_loss=weight * 122.22222,
- weights=weight)
+ labels, predictions, expected_loss=weight * 137.5, weights=weight)
def _test_invalid_weights(
self, labels, predictions, weights=1.0):
@@ -1124,7 +1122,9 @@ class MeanPairwiseSquaredErrorTest(test.TestCase):
])
self._test_valid_weights(
# TODO(ptucker): This doesn't look right.
- labels, predictions, expected_loss=9 * 122.22222,
+ labels,
+ predictions,
+ expected_loss=9 * 137.5,
weights=np.ones((2, 3, 3)))
def testLossWithAllZeroBatchSpecificWeights(self):
diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py
new file mode 100644
index 0000000000..b8200ac0cb
--- /dev/null
+++ b/tensorflow/python/kernel_tests/manip_ops_test.py
@@ -0,0 +1,138 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for manip_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import manip_ops
+from tensorflow.python.platform import test as test_lib
+
+# pylint: disable=g-import-not-at-top
+try:
+ from distutils.version import StrictVersion as Version
+ # numpy.roll for multiple shifts was introduced in numpy version 1.12.0
+ NP_ROLL_CAN_MULTISHIFT = Version(np.version.version) >= Version("1.12.0")
+except ImportError:
+ NP_ROLL_CAN_MULTISHIFT = False
+# pylint: enable=g-import-not-at-top
+
+
+class RollTest(test_util.TensorFlowTestCase):
+
+ def _testRoll(self, np_input, shift, axis):
+ expected_roll = np.roll(np_input, shift, axis)
+ with self.test_session():
+ roll = manip_ops.roll(np_input, shift, axis)
+ self.assertAllEqual(roll.eval(), expected_roll)
+
+ def _testGradient(self, np_input, shift, axis):
+ with self.test_session():
+ inx = constant_op.constant(np_input.tolist())
+ xs = list(np_input.shape)
+ y = manip_ops.roll(inx, shift, axis)
+ # Expected y's shape to be the same
+ ys = xs
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ inx, xs, y, ys, x_init_value=np_input)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5)
+
+ def _testAll(self, np_input, shift, axis):
+ self._testRoll(np_input, shift, axis)
+ if np_input.dtype == np.float32:
+ self._testGradient(np_input, shift, axis)
+
+ def testIntTypes(self):
+ for t in [np.int32, np.int64]:
+ self._testAll(np.random.randint(-100, 100, (5)).astype(t), 3, 0)
+ if NP_ROLL_CAN_MULTISHIFT:
+ self._testAll(
+ np.random.randint(-100, 100, (4, 4, 3)).astype(t), [1, -2, 3],
+ [0, 1, 2])
+ self._testAll(
+ np.random.randint(-100, 100, (4, 2, 1, 3)).astype(t), [0, 1, -2],
+ [1, 2, 3])
+
+ def testFloatTypes(self):
+ for t in [np.float32, np.float64]:
+ self._testAll(np.random.rand(5).astype(t), 2, 0)
+ if NP_ROLL_CAN_MULTISHIFT:
+ self._testAll(np.random.rand(3, 4).astype(t), [1, 2], [1, 0])
+ self._testAll(np.random.rand(1, 3, 4).astype(t), [1, 0, -3], [0, 1, 2])
+
+ def testComplexTypes(self):
+ for t in [np.complex64, np.complex128]:
+ x = np.random.rand(4, 4).astype(t)
+ self._testAll(x + 1j * x, 2, 0)
+ if NP_ROLL_CAN_MULTISHIFT:
+ x = np.random.rand(2, 5).astype(t)
+ self._testAll(x + 1j * x, [1, 2], [1, 0])
+ x = np.random.rand(3, 2, 1, 1).astype(t)
+ self._testAll(x + 1j * x, [2, 1, 1, 0], [0, 3, 1, 2])
+
+ def testRollInputMustVectorHigherRaises(self):
+ tensor = 7
+ shift = 1
+ axis = 0
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "input must be 1-D or higher"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+ def testRollAxisMustBeScalarOrVectorRaises(self):
+ tensor = [[1, 2], [3, 4]]
+ shift = 1
+ axis = [[0, 1]]
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "axis must be a scalar or a 1-D vector"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+ def testRollShiftMustBeScalarOrVectorRaises(self):
+ tensor = [[1, 2], [3, 4]]
+ shift = [[0, 1]]
+ axis = 1
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "shift must be a scalar or a 1-D vector"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+ def testRollShiftAndAxisMustBeSameSizeRaises(self):
+ tensor = [[1, 2], [3, 4]]
+ shift = [1]
+ axis = [0, 1]
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "shift and axis must have the same size"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+ def testRollAxisOutOfRangeRaises(self):
+ tensor = [1, 2]
+ shift = 1
+ axis = 1
+ with self.test_session():
+ with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
+ "is out of range"):
+ manip_ops.roll(tensor, shift, axis).eval()
+
+
+if __name__ == "__main__":
+ test_lib.main()
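The NP_ROLL_CAN_MULTISHIFT guard in this test file exists because np.roll only accepts sequences for shift/axis from NumPy 1.12 onward, while the op under test supports that form unconditionally. What the guarded cases compare against, in plain NumPy (the second call assumes NumPy >= 1.12):

    import numpy as np

    x = np.arange(10).reshape(2, 5)

    # Single shift along one axis: each row rotates right by 2.
    print(np.roll(x, 2, axis=1))

    # One shift per axis, applied together (NumPy >= 1.12 only).
    print(np.roll(x, [1, -2], axis=[0, 1]))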
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 0c77d1db92..daa42938e6 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -23,6 +23,7 @@ import timeit
import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib import rnn as contrib_rnn
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py
index f1670a47f5..8ad29afd0a 100644
--- a/tensorflow/python/kernel_tests/tensordot_op_test.py
+++ b/tensorflow/python/kernel_tests/tensordot_op_test.py
@@ -66,7 +66,7 @@ class TensordotTest(test_lib.TestCase):
a = [[1, 2], [3, 4]]
b = [[1, 2], [3, 4]]
# Invalid static axes.
- for axes_value in -1, 0, [1], [[1]], [[1], [0, 1]]:
+ for axes_value in -1, 3, [1], [[1]], [[1], [0, 1]]:
with self.assertRaises(ValueError):
math_ops.tensordot(a, b, axes_value)
@@ -91,7 +91,7 @@ class TensordotTest(test_lib.TestCase):
# Test case for 11950
def test_valid_axis(self):
- for axes_value in [1, 2], [[1], [2]]:
+ for axes_value in [1, 2], [[1], [2]], [[], []], 0:
with self.test_session() as sess:
np_a = np.ones((3, 3))
np_b = np.array([2, 3, 1])[None, None]
@@ -105,29 +105,29 @@ class TensordotTest(test_lib.TestCase):
self.assertAllEqual(tf_ans, np_ans)
def test_partial_shape_inference(self):
- a = array_ops.placeholder(dtypes.float32)
- b = array_ops.placeholder(dtypes.float32)
- axes = ([1], [0])
- output = math_ops.tensordot(a, b, axes)
- self.assertEqual(output.get_shape().ndims, None)
- a.set_shape([None, 2])
- b.set_shape([2, 3])
- output = math_ops.tensordot(a, b, axes)
- output_shape = output.get_shape()
- self.assertEqual(output_shape.ndims, 2)
- output_shape = output_shape.as_list()
- self.assertEqual(output_shape[0], None)
- self.assertEqual(output_shape[1], 3)
- a = array_ops.placeholder(dtypes.float32)
- b = array_ops.placeholder(dtypes.float32)
- a.set_shape([2, 2])
- b.set_shape([2, None])
- output = math_ops.tensordot(a, b, axes)
- output_shape = output.get_shape()
- self.assertEqual(output_shape.ndims, 2)
- output_shape = output_shape.as_list()
- self.assertEqual(output_shape[0], 2)
- self.assertEqual(output_shape[1], None)
+ for axes in ([1], [0]), 1:
+ a = array_ops.placeholder(dtypes.float32)
+ b = array_ops.placeholder(dtypes.float32)
+ output = math_ops.tensordot(a, b, axes)
+ self.assertEqual(output.get_shape().ndims, None)
+ a.set_shape([None, 2])
+ b.set_shape([2, 3])
+ output = math_ops.tensordot(a, b, axes)
+ output_shape = output.get_shape()
+ self.assertEqual(output_shape.ndims, 2)
+ output_shape = output_shape.as_list()
+ self.assertEqual(output_shape[0], None)
+ self.assertEqual(output_shape[1], 3)
+ a = array_ops.placeholder(dtypes.float32)
+ b = array_ops.placeholder(dtypes.float32)
+ a.set_shape([2, 2])
+ b.set_shape([2, None])
+ output = math_ops.tensordot(a, b, axes)
+ output_shape = output.get_shape()
+ self.assertEqual(output_shape.ndims, 2)
+ output_shape = output_shape.as_list()
+ self.assertEqual(output_shape[0], 2)
+ self.assertEqual(output_shape[1], None)
def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
@@ -196,8 +196,8 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_):
low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_)
b_np = np.random.uniform(
low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_)
- all_axes = [1]
- if a_np.ndim > 1:
+ all_axes = [0, 1]
+ if a_np.ndim > 2:
all_axes.append(a_np.ndim - 1)
for axes in all_axes:
np_ans = np.tensordot(a_np, b_np, axes=axes)
diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py
index efb5b9f364..6ab931fdb9 100644
--- a/tensorflow/python/kernel_tests/topk_op_test.py
+++ b/tensorflow/python/kernel_tests/topk_op_test.py
@@ -58,7 +58,7 @@ class TopKTest(test.TestCase):
# Do some special casing of equality of indices: if indices
# are not the same, but values are floating type, ensure that
# the values are within epsilon of each other.
- if not np.issubdtype(np_expected_values.dtype, np.float):
+ if not np.issubdtype(np_expected_values.dtype, np.floating):
# Values are not floating point type; check indices exactly
self.assertAllEqual(np_expected_indices, indices)
else:
diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py
index 79c421f4c9..e8dba3cea3 100644
--- a/tensorflow/python/layers/convolutional.py
+++ b/tensorflow/python/layers/convolutional.py
@@ -1094,7 +1094,7 @@ class SeparableConv1D(_SeparableConv):
strides = (1, 1, 1) + self.strides
spatial_start_dim = 2
- # Explictly broadcast inputs and kernels to 4D.
+ # Explicitly broadcast inputs and kernels to 4D.
# TODO(fchollet): refactor when a native separable_conv1d op is available.
inputs = array_ops.expand_dims(inputs, spatial_start_dim)
depthwise_kernel = array_ops.expand_dims(self.depthwise_kernel, 0)
@@ -1904,6 +1904,7 @@ class Conv3DTranspose(Conv3D):
dtype=self.dtype)
else:
self.bias = None
+ self.built = True
def call(self, inputs):
inputs_shape = array_ops.shape(inputs)
@@ -1974,6 +1975,8 @@ class Conv3DTranspose(Conv3D):
if self.use_bias:
outputs_shape = outputs.shape.as_list()
+ if outputs_shape[0] is None:
+ outputs_shape[0] = -1
if self.data_format == 'channels_first':
outputs_4d = array_ops.reshape(outputs, [
outputs_shape[0], outputs_shape[1],
@@ -2007,11 +2010,11 @@ class Conv3DTranspose(Conv3D):
output_shape[c_axis] = self.filters
output_shape[d_axis] = utils.deconv_output_length(
- output_shape[d_axis], stride_d, kernel_d, self.padding)
+ output_shape[d_axis], kernel_d, self.padding, stride_d)
output_shape[h_axis] = utils.deconv_output_length(
- output_shape[h_axis], stride_h, kernel_h, self.padding)
+ output_shape[h_axis], kernel_h, self.padding, stride_h)
output_shape[w_axis] = utils.deconv_output_length(
- output_shape[w_axis], stride_w, kernel_w, self.padding)
+ output_shape[w_axis], kernel_w, self.padding, stride_w)
return tensor_shape.TensorShape(output_shape)
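The reordered deconv_output_length calls above match the helper's signature, (input_length, filter_size, padding, stride); the old call sites passed the stride where the kernel size belonged. As a sketch, the usual transposed-convolution length rule (stated here as an assumption about utils.py, not a quote of it) is:

    def deconv_output_length(input_length, filter_size, padding, stride):
        """Output length of a transposed convolution (sketch)."""
        if input_length is None:
            return None
        if padding == 'valid':
            return input_length * stride + max(filter_size - stride, 0)
        elif padding == 'same':
            return input_length * stride
        elif padding == 'full':
            return input_length * stride - (stride + filter_size - 2)

    # e.g. length 4, kernel 3, stride 2:  'same' -> 8, 'valid' -> 9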
diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py
index e8be347799..7407d9a7b3 100644
--- a/tensorflow/python/layers/utils.py
+++ b/tensorflow/python/layers/utils.py
@@ -81,7 +81,7 @@ def normalize_tuple(value, n, name):
for single_value in value_tuple:
try:
int(single_value)
- except ValueError:
+ except (ValueError, TypeError):
raise ValueError('The `' + name + '` argument must be a tuple of ' +
str(n) + ' integers. Received: ' + str(value) + ' '
'including element ' + str(single_value) + ' of type' +
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index e3902f5a8a..ad409ad7e5 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -35,6 +35,7 @@ See the @{$python/array_ops} guide.
@@reshape
@@squeeze
@@expand_dims
+@@unravel_index
@@meshgrid
@@slice
@@strided_slice
@@ -1589,9 +1590,9 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
Args:
tensor: A `Tensor`.
- dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
- `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`,
- `complex64`, `complex128` or `bool`.
+ dtype: A type for the returned `Tensor`. Must be `float16`, `float32`,
+ `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
+ `complex64`, `complex128`, `bool` or `string`.
name: A name for the operation (optional).
optimize: if true, attempt to statically determine the shape of 'tensor'
and encode it as a constant.
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index 7dbccf1caf..ac03d30fcd 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -458,7 +458,7 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
`[i1, i2]` then an appropriate signature for `fn` in `python2` is:
- `fn = lambda (acc_p1, acc_p2), (t1 [t2, t3]):` and `fn` must return a list,
+ `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
`[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the
one that works in `python3`, is:
`fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 28b26a09a5..9f06c0ee1f 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -44,6 +44,7 @@ from tensorflow.python.ops import image_grad # pylint: disable=unused-import
from tensorflow.python.ops import linalg_grad # pylint: disable=unused-import
from tensorflow.python.ops import linalg_ops # pylint: disable=unused-import
from tensorflow.python.ops import logging_ops # pylint: disable=unused-import
+from tensorflow.python.ops import manip_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index 3b0b5a978c..de12c5f63f 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -49,6 +49,10 @@ See the @{$python/image} guide.
@@grayscale_to_rgb
@@hsv_to_rgb
@@rgb_to_hsv
+@@rgb_to_yiq
+@@yiq_to_rgb
+@@rgb_to_yuv
+@@yuv_to_rgb
@@convert_image_dtype
@@adjust_brightness
@@random_brightness
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 2c231ef56c..14a38f25d1 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -1508,7 +1508,7 @@ def sample_distorted_bounding_box(image_size,
bounding_boxes,
seed=None,
seed2=None,
- min_object_covered=None,
+ min_object_covered=0.1,
aspect_ratio_range=None,
area_range=None,
max_attempts=None,
@@ -1669,3 +1669,107 @@ def non_max_suppression(boxes,
return gen_image_ops._non_max_suppression_v2(boxes, scores, max_output_size,
iou_threshold)
# pylint: enable=protected-access
+
+
+_rgb_to_yiq_kernel = [[0.299, 0.59590059,
+ 0.2115], [0.587, -0.27455667, -0.52273617],
+ [0.114, -0.32134392, 0.31119955]]
+
+
+def rgb_to_yiq(images):
+ """Converts one or more images from RGB to YIQ.
+
+ Outputs a tensor of the same shape as the `images` tensor, containing the YIQ
+ value of the pixels.
+ The output is only well defined if the value in images are in [0,1].
+
+ Args:
+ images: 2-D or higher rank. Image data to convert. Last dimension must be
+ size 3.
+
+ Returns:
+ images: tensor with the same shape as `images`.
+ """
+ images = ops.convert_to_tensor(images, name='images')
+ kernel = ops.convert_to_tensor(
+ _rgb_to_yiq_kernel, dtype=images.dtype, name='kernel')
+ ndims = images.get_shape().ndims
+ return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]])
+
+
+_yiq_to_rgb_kernel = [[1, 1, 1], [0.95598634, -0.27201283, -1.10674021],
+ [0.6208248, -0.64720424, 1.70423049]]
+
+
+def yiq_to_rgb(images):
+ """Converts one or more images from YIQ to RGB.
+
+ Outputs a tensor of the same shape as the `images` tensor, containing the RGB
+ value of the pixels.
+ The output is only well defined if the Y value in images are in [0,1],
+ I value are in [-0.5957,0.5957] and Q value are in [-0.5226,0.5226].
+
+ Args:
+ images: 2-D or higher rank. Image data to convert. Last dimension must be
+ size 3.
+
+ Returns:
+ images: tensor with the same shape as `images`.
+ """
+ images = ops.convert_to_tensor(images, name='images')
+ kernel = ops.convert_to_tensor(
+ _yiq_to_rgb_kernel, dtype=images.dtype, name='kernel')
+ ndims = images.get_shape().ndims
+ return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]])
+
+
+_rgb_to_yuv_kernel = [[0.299, -0.14714119,
+ 0.61497538], [0.587, -0.28886916, -0.51496512],
+ [0.114, 0.43601035, -0.10001026]]
+
+
+def rgb_to_yuv(images):
+ """Converts one or more images from RGB to YUV.
+
+ Outputs a tensor of the same shape as the `images` tensor, containing the YUV
+ value of the pixels.
+ The output is only well defined if the value in images are in [0,1].
+
+ Args:
+ images: 2-D or higher rank. Image data to convert. Last dimension must be
+ size 3.
+
+ Returns:
+ images: tensor with the same shape as `images`.
+ """
+ images = ops.convert_to_tensor(images, name='images')
+ kernel = ops.convert_to_tensor(
+ _rgb_to_yuv_kernel, dtype=images.dtype, name='kernel')
+ ndims = images.get_shape().ndims
+ return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]])
+
+
+_yuv_to_rgb_kernel = [[1, 1, 1], [0, -0.394642334, 2.03206185],
+ [1.13988303, -0.58062185, 0]]
+
+
+def yuv_to_rgb(images):
+ """Converts one or more images from YUV to RGB.
+
+ Outputs a tensor of the same shape as the `images` tensor, containing the RGB
+ value of the pixels.
+ The output is only well defined if the Y value in images are in [0,1],
+ U and V value are in [-0.5,0.5].
+
+ Args:
+ images: 2-D or higher rank. Image data to convert. Last dimension must be
+ size 3.
+
+ Returns:
+ images: tensor with the same shape as `images`.
+ """
+ images = ops.convert_to_tensor(images, name='images')
+ kernel = ops.convert_to_tensor(
+ _yuv_to_rgb_kernel, dtype=images.dtype, name='kernel')
+ ndims = images.get_shape().ndims
+ return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]])
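All four new converters are a single tensordot against a fixed 3x3 kernel over the trailing axis, so they accept any rank >= 2 input whose last dimension is 3. A round-trip sketch (the kernels are approximate inverses of each other):

    import numpy as np
    import tensorflow as tf

    rgb = tf.constant(np.random.rand(2, 4, 4, 3), dtype=tf.float32)
    yiq = tf.image.rgb_to_yiq(rgb)
    back = tf.image.yiq_to_rgb(yiq)

    with tf.Session() as sess:
        a, b = sess.run([rgb, back])
        print(np.max(np.abs(a - b)))   # small, on the order of 1e-6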
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 47dd8231c0..b12bd3d5b0 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -85,6 +85,64 @@ class RGBToHSVTest(test_util.TensorFlowTestCase):
self.assertAllClose(rgb_tf, rgb_np)
+class RGBToYIQTest(test_util.TensorFlowTestCase):
+
+ def testBatch(self):
+ # Build an arbitrary RGB image
+ np.random.seed(7)
+ batch_size = 5
+ shape = (batch_size, 2, 7, 3)
+
+ for nptype in [np.float32, np.float64]:
+ inp = np.random.rand(*shape).astype(nptype)
+
+ # Convert to YIQ and back, as a batch and individually
+ with self.test_session(use_gpu=True) as sess:
+ batch0 = constant_op.constant(inp)
+ batch1 = image_ops.rgb_to_yiq(batch0)
+ batch2 = image_ops.yiq_to_rgb(batch1)
+ split0 = array_ops.unstack(batch0)
+ split1 = list(map(image_ops.rgb_to_yiq, split0))
+ split2 = list(map(image_ops.yiq_to_rgb, split1))
+ join1 = array_ops.stack(split1)
+ join2 = array_ops.stack(split2)
+ batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2])
+
+ # Verify that processing batch elements together is the same as separate
+ self.assertAllClose(batch1, join1, rtol=1e-4, atol=1e-4)
+ self.assertAllClose(batch2, join2, rtol=1e-4, atol=1e-4)
+ self.assertAllClose(batch2, inp, rtol=1e-4, atol=1e-4)
+
+
+class RGBToYUVTest(test_util.TensorFlowTestCase):
+
+ def testBatch(self):
+ # Build an arbitrary RGB image
+ np.random.seed(7)
+ batch_size = 5
+ shape = (batch_size, 2, 7, 3)
+
+ for nptype in [np.float32, np.float64]:
+ inp = np.random.rand(*shape).astype(nptype)
+
+ # Convert to YUV and back, as a batch and individually
+ with self.test_session(use_gpu=True) as sess:
+ batch0 = constant_op.constant(inp)
+ batch1 = image_ops.rgb_to_yuv(batch0)
+ batch2 = image_ops.yuv_to_rgb(batch1)
+ split0 = array_ops.unstack(batch0)
+ split1 = list(map(image_ops.rgb_to_yuv, split0))
+ split2 = list(map(image_ops.yuv_to_rgb, split1))
+ join1 = array_ops.stack(split1)
+ join2 = array_ops.stack(split2)
+ batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2])
+
+ # Verify that processing batch elements together is the same as separate
+ self.assertAllClose(batch1, join1, rtol=1e-4, atol=1e-4)
+ self.assertAllClose(batch2, join2, rtol=1e-4, atol=1e-4)
+ self.assertAllClose(batch2, inp, rtol=1e-4, atol=1e-4)
+
+
class GrayscaleToRGBTest(test_util.TensorFlowTestCase):
def _RGBToGrayscale(self, images):
@@ -1839,6 +1897,26 @@ class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase):
self.assertAllEqual([3], end.get_shape().as_list())
self.assertAllEqual([1, 1, 4], bbox_for_drawing.get_shape().as_list())
+ def testDefaultMinObjectCovered(self):
+ # By default min_object_covered=0.1 if not provided
+ with self.test_session(use_gpu=True):
+ image_size = constant_op.constant(
+ [40, 50, 1], shape=[3], dtype=dtypes.int32)
+ bounding_box = constant_op.constant(
+ [0.0, 0.0, 1.0, 1.0],
+ shape=[4],
+ dtype=dtypes.float32,
+ )
+ begin, end, bbox_for_drawing = image_ops.sample_distorted_bounding_box(
+ image_size=image_size,
+ bounding_boxes=bounding_box,
+ aspect_ratio_range=(0.75, 1.33),
+ area_range=(0.05, 1.0))
+
+ self.assertAllEqual([3], begin.get_shape().as_list())
+ self.assertAllEqual([3], end.get_shape().as_list())
+ self.assertAllEqual([1, 1, 4], bbox_for_drawing.get_shape().as_list())
+
class ResizeImagesTest(test_util.TensorFlowTestCase):
@@ -3092,6 +3170,40 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
boxes, scores, max_output_size, iou_threshold).eval()
self.assertAllClose(selected_indices, [3, 0, 5])
+ def testInvalidShape(self):
+ # The boxes should be 2D of shape [num_boxes, 4].
+ with self.assertRaisesRegexp(ValueError,
+ "Shape must be rank 2 but is rank 1"):
+ boxes = constant_op.constant([0.0, 0.0, 1.0, 1.0])
+ scores = constant_op.constant([0.9])
+ image_ops.non_max_suppression(boxes, scores, 3, 0.5)
+
+ with self.assertRaisesRegexp(ValueError, "Dimension must be 4 but is 3"):
+ boxes = constant_op.constant([[0.0, 0.0, 1.0]])
+ scores = constant_op.constant([0.9])
+ image_ops.non_max_suppression(boxes, scores, 3, 0.5)
+
+ # The scores should be 1D of shape [num_boxes].
+ with self.assertRaisesRegexp(ValueError,
+ "Shape must be rank 1 but is rank 2"):
+ boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
+ scores = constant_op.constant([[0.9]])
+ image_ops.non_max_suppression(boxes, scores, 3, 0.5)
+
+ # The max_output_size should be a scaler (0-D).
+ with self.assertRaisesRegexp(ValueError,
+ "Shape must be rank 0 but is rank 1"):
+ boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
+ scores = constant_op.constant([0.9])
+ image_ops.non_max_suppression(boxes, scores, [3], 0.5)
+
+ # The iou_threshold should be a scaler (0-D).
+ with self.assertRaisesRegexp(ValueError,
+ "Shape must be rank 0 but is rank 2"):
+ boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
+ scores = constant_op.constant([0.9])
+ image_ops.non_max_suppression(boxes, scores, 3, [[0.5]])
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index 13a32c83d9..3cbbf3412a 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -277,20 +277,28 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
# https://j-towns.github.io/papers/svd-derivative.pdf
a = op.inputs[0]
a_shape = a.get_shape().with_rank_at_least(2)
+ grad_s_mat = array_ops.matrix_diag(grad_s)
- if op.get_attr("compute_uv"):
- # TODO(rmlarsen): Make this work with complex types.
- if a.dtype.is_complex:
- raise NotImplementedError(
- "SVD gradient is not implemented for complex types and "
- "compute_uv=True.")
- grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
- grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
- m = a_shape[-2].merge_with(grad_u_shape[-2])
- n = a_shape[-1].merge_with(grad_v_shape[-2])
- batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
- grad_v_shape[:-2])
- a_shape = batch_shape.concatenate([m, n])
+ if not op.get_attr("compute_uv"):
+ s, u, v = linalg_ops.svd(a, compute_uv=True)
+ grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
+ grad_a.set_shape(a_shape)
+ return grad_a
+
+ full_matrices = op.get_attr("full_matrices")
+
+ # TODO(rmlarsen): Make this work with complex types.
+ if a.dtype.is_complex:
+ raise NotImplementedError(
+ "SVD gradient is not implemented for complex types and "
+ "compute_uv=True.")
+ grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
+ grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
+ m = a_shape[-2].merge_with(grad_u_shape[-2])
+ n = a_shape[-1].merge_with(grad_v_shape[-2])
+ batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
+ grad_v_shape[:-2])
+ a_shape = batch_shape.concatenate([m, n])
m = a_shape[-2].value
n = a_shape[-1].value
@@ -300,12 +308,9 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
"SVD gradient has not been implemented for input with unknown "
"inner matrix shape.")
- if not op.get_attr("compute_uv"):
- s, u, v = linalg_ops.svd(a, compute_uv=True, full_matrices=True)
- else:
- s = op.outputs[0]
- u = op.outputs[1]
- v = op.outputs[2]
+ s = op.outputs[0]
+ u = op.outputs[1]
+ v = op.outputs[2]
use_adjoint = False
if m > n:
@@ -317,19 +322,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
grad_u, grad_v = grad_v, grad_u
with ops.control_dependencies([grad_s, grad_u, grad_v]):
- grad_s_mat = array_ops.matrix_diag(grad_s)
- if not op.get_attr("compute_uv"):
- if use_adjoint:
- grad_a = math_ops.matmul(
- v[..., :, :m], math_ops.matmul(u, grad_s_mat), adjoint_b=True)
- else:
- grad_a = math_ops.matmul(u,
- math_ops.matmul(
- grad_s_mat, v[..., :, :m], adjoint_b=True))
- grad_a.set_shape(a_shape)
- return grad_a
-
- if op.get_attr("full_matrices") and abs(m - n) > 1:
+ if full_matrices and abs(m - n) > 1:
raise NotImplementedError(
"svd gradient is not implemented for abs(m - n) > 1 "
"when full_matrices is True")
@@ -371,7 +364,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
gv1t_v1 = math_ops.matmul(gv1t, v1)
term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)
- if op.get_attr("full_matrices"):
+ if full_matrices:
v2 = v[..., :, m:n]
grad_v2 = grad_v[..., :, m:n]
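The restructuring above pulls the compute_uv=False case to the top: when only the singular values reached the loss, the gradient is u * diag(grad_s) * v^T, with u and v recomputed by a fresh svd call since the op never returned them. A NumPy sanity sketch of that formula, assuming a square input with distinct singular values:

    import numpy as np

    a = np.random.rand(4, 4)
    u, s, vt = np.linalg.svd(a)
    grad_s = np.random.rand(4)                 # stand-in for dL/ds

    # dL/dA = U diag(dL/ds) V^T
    grad_a = np.dot(u, np.dot(np.diag(grad_s), vt))

    # Finite-difference check of one entry, using ds_k/da_ij = u_ik * v_jk:
    eps = 1e-6
    a2 = a.copy()
    a2[0, 1] += eps
    s2 = np.linalg.svd(a2, compute_uv=False)
    print(np.dot((s2 - s) / eps, grad_s), grad_a[0, 1])   # agree to ~1e-4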
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index e75a9b22e4..84afbf0627 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -547,12 +547,13 @@ def mean_pairwise_squared_error(
num_present_per_batch = _num_present(diffs, weights, per_batch=True)
term1 = 2.0 * _safe_div(sum_squares_diff_per_batch,
- num_present_per_batch)
+ num_present_per_batch - 1)
sum_diff = math_ops.reduce_sum(
diffs, reduction_indices=reduction_indices, keep_dims=True)
- term2 = 2.0 * _safe_div(math_ops.square(sum_diff),
- math_ops.square(num_present_per_batch))
+ term2 = 2.0 * _safe_div(
+ math_ops.square(sum_diff),
+ math_ops.multiply(num_present_per_batch, num_present_per_batch - 1))
weighted_losses = math_ops.multiply(term1 - term2, weights)
loss = math_ops.reduce_sum(weighted_losses)
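The corrected denominators turn the loss into a mean over the n*(n-1)/2 distinct index pairs instead of all n^2 ordered pairs. With d = predictions - labels per element, the identity sum_{i<j} (d_i - d_j)^2 = n*sum(d^2) - (sum d)^2 makes term1 - term2 above exactly that pairwise mean; a small NumPy check:

    import numpy as np

    d = np.array([0.3, -1.2, 0.7, 2.0])   # per-element diffs
    n = d.size

    pairwise = sum((d[i] - d[j]) ** 2
                   for i in range(n) for j in range(i + 1, n))
    assert np.isclose(pairwise, n * np.sum(d ** 2) - np.sum(d) ** 2)

    term1 = 2.0 * np.sum(d ** 2) / (n - 1)
    term2 = 2.0 * np.sum(d) ** 2 / (n * (n - 1))
    assert np.isclose(term1 - term2, pairwise / (n * (n - 1) / 2.0))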
diff --git a/tensorflow/python/ops/manip_grad.py b/tensorflow/python/ops/manip_grad.py
new file mode 100644
index 0000000000..bb2069359d
--- /dev/null
+++ b/tensorflow/python/ops/manip_grad.py
@@ -0,0 +1,31 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Gradients for operators defined in manip_ops.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import manip_ops
+
+
+@ops.RegisterGradient("Roll")
+def _RollGrad(op, grad):
+ # The gradient is just the roll reversed
+ shift = op.inputs[1]
+ axis = op.inputs[2]
+ roll_grad = manip_ops.roll(grad, -shift, axis)
+ return roll_grad, None, None
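Roll permutes elements without changing them, so its Jacobian is a permutation matrix and backprop just applies the inverse permutation, i.e. the same roll with the shift negated (the trailing None, None are for the integer shift and axis inputs, which get no gradient). The inverse property in one line of NumPy:

    import numpy as np

    g = np.arange(12).reshape(3, 4)
    # Rolling by -shift undoes rolling by shift, which is exactly why
    # roll(grad, -shift, axis) is the correct gradient.
    assert np.array_equal(np.roll(np.roll(g, 2, axis=1), -2, axis=1), g)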
diff --git a/tensorflow/python/ops/manip_ops.py b/tensorflow/python/ops/manip_ops.py
new file mode 100644
index 0000000000..91e15b47b9
--- /dev/null
+++ b/tensorflow/python/ops/manip_ops.py
@@ -0,0 +1,38 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators for manipulating tensors.
+
+@@roll
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops
+from tensorflow.python.util.all_util import remove_undocumented
+
+
+# pylint: disable=protected-access
+def roll(input, shift, axis): # pylint: disable=redefined-builtin
+ return _gen_manip_ops.roll(input, shift, axis)
+
+
+roll.__doc__ = _gen_manip_ops.roll.__doc__
+# pylint: enable=protected-access
+
+_allowed_symbols = ['roll']
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
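Together with the __init__.py change earlier in this diff, the wrapper becomes available as tf.manip.roll. A usage sketch in TF 1.x style:

    import tensorflow as tf

    x = tf.constant([1, 2, 3, 4, 5])

    with tf.Session() as sess:
        # Elements rotate toward higher indices; prints [4 5 1 2 3].
        print(sess.run(tf.manip.roll(x, shift=2, axis=0)))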
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 827e3caa36..9a8ac93de9 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -2826,10 +2826,14 @@ def tensordot(a, b, axes, name=None):
"""Generates two sets of contraction axes for the two tensor arguments."""
a_shape = a.get_shape()
if isinstance(axes, compat.integral_types):
- if axes < 1:
- raise ValueError("'axes' must be at least 1.")
+ if axes < 0:
+ raise ValueError("'axes' must be at least 0.")
if a_shape.ndims is not None:
- return range(a_shape.ndims - axes, a_shape.ndims), range(axes)
+ if axes > a_shape.ndims:
+ raise ValueError("'axes' must not be larger than the number of "
+ "dimensions of tensor %s." % a)
+ return (list(xrange(a_shape.ndims - axes, a_shape.ndims)),
+ list(xrange(axes)))
else:
rank = array_ops.rank(a)
return (range(rank - axes, rank, dtype=dtypes.int32),
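Relaxing the lower bound from 1 to 0 makes tensordot(a, b, axes=0) legal: contracting over zero axes is an outer product, matching NumPy's behavior (this is what the new [[], []] and 0 test cases above rely on). A quick check:

    import numpy as np

    a = np.array([1.0, 2.0])
    b = np.array([3.0, 4.0, 5.0])

    out = np.tensordot(a, b, axes=0)      # contracts nothing
    assert out.shape == (2, 3)            # a.shape + b.shape
    assert np.array_equal(out, np.outer(a, b))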
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 24c6f64f0a..da80e72071 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -1127,6 +1127,12 @@ def raw_rnn(cell, loop_fn,
def _copy_some_through(current, candidate):
"""Copy some tensors through via array_ops.where."""
def copy_fn(cur_i, cand_i):
+ # TensorArrays and scalars are passed through: neither has a batch
+ # dimension for array_ops.where to select over.
+ if isinstance(cur_i, tensor_array_ops.TensorArray):
+ return cand_i
+ if cur_i.shape.ndims == 0:
+ return cand_i
+ # Otherwise propagate the old or the new value.
with ops.colocate_with(cand_i):
return array_ops.where(elements_finished, cur_i, cand_i)
return nest.map_structure(copy_fn, current, candidate)
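The early returns above can be motivated outside `raw_rnn`: `array_ops.where` selects per batch element, so it needs operands with a leading batch dimension, which `TensorArray`s and scalars lack. A standalone sketch of the ordinary-tensor path:

```python
import tensorflow as tf

elements_finished = tf.constant([True, False, True])
cur = tf.zeros([3, 2])    # values carried over for finished elements
cand = tf.ones([3, 2])    # freshly computed values

# What copy_fn does for ordinary tensors: pick per batch element.
mixed = tf.where(elements_finished, cur, cand)

# A scalar has no batch dimension to select over, so the patched copy_fn
# returns the candidate unchanged instead of calling where().

with tf.Session() as sess:
  print(sess.run(mixed))  # rows 0 and 2 from cur, row 1 from cand
```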
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 30bf4e4ef1..009d1dc3b9 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -25,6 +25,7 @@ import sys as _sys
# Imports the following modules so that @RegisterGradient get executed.
from tensorflow.python.ops import array_grad
from tensorflow.python.ops import data_flow_grad
+from tensorflow.python.ops import manip_grad
from tensorflow.python.ops import math_grad
from tensorflow.python.ops import sparse_grad
from tensorflow.python.ops import spectral_grad
@@ -42,11 +43,13 @@ from tensorflow.python.ops.special_math_ops import *
# TODO(vrv): Switch to import * once we're okay with exposing the module.
from tensorflow.python.ops.confusion_matrix import confusion_matrix
from tensorflow.python.ops.control_flow_ops import Assert
+from tensorflow.python.ops.control_flow_ops import case
+from tensorflow.python.ops.control_flow_ops import cond
from tensorflow.python.ops.control_flow_ops import group
from tensorflow.python.ops.control_flow_ops import no_op
+# pylint: disable=redefined-builtin
from tensorflow.python.ops.control_flow_ops import tuple
-from tensorflow.python.ops.control_flow_ops import cond
-from tensorflow.python.ops.control_flow_ops import case
+# pylint: enable=redefined-builtin
from tensorflow.python.ops.control_flow_ops import while_loop
from tensorflow.python.ops.data_flow_ops import *
from tensorflow.python.ops.functional_ops import *
@@ -59,6 +62,7 @@ from tensorflow.python.ops.logging_ops import Print
from tensorflow.python.ops.logging_ops import get_summary_op
from tensorflow.python.ops.lookup_ops import initialize_all_tables
from tensorflow.python.ops.lookup_ops import tables_initializer
+from tensorflow.python.ops.manip_ops import *
from tensorflow.python.ops.math_ops import *
from tensorflow.python.ops.numerics import *
from tensorflow.python.ops.parsing_ops import *
@@ -105,6 +109,7 @@ from tensorflow.python.ops import init_ops as _init_ops
from tensorflow.python.ops import io_ops as _io_ops
from tensorflow.python.ops import linalg_ops as _linalg_ops
from tensorflow.python.ops import logging_ops as _logging_ops
+from tensorflow.python.ops import manip_ops as _manip_ops
from tensorflow.python.ops import math_ops as _math_ops
from tensorflow.python.ops import numerics as _numerics
from tensorflow.python.ops import parsing_ops as _parsing_ops
@@ -264,34 +269,36 @@ _allowed_symbols = (_allowed_symbols_array_ops +
_allowed_symbols_misc +
_allowed_symbols_partitioned_variables)
-remove_undocumented(__name__, _allowed_symbols,
- [_sys.modules[__name__],
- _array_ops,
- _check_ops,
- _clip_ops,
- _confusion_matrix,
- _control_flow_ops,
- _constant_op,
- _data_flow_ops,
- _functional_ops,
- _gradients,
- _histogram_ops,
- _init_ops,
- _io_ops,
- _linalg_ops,
- _logging_ops,
- _math_ops,
- _numerics,
- _parsing_ops,
- _partitioned_variables,
- _random_ops,
- _script_ops,
- _session_ops,
- _sparse_ops,
- _special_math_ops,
- _state_ops,
- _string_ops,
- _template,
- _tensor_array_ops,
- _variable_scope,
- _variables,])
+remove_undocumented(__name__, _allowed_symbols, [
+ _sys.modules[__name__],
+ _array_ops,
+ _check_ops,
+ _clip_ops,
+ _confusion_matrix,
+ _control_flow_ops,
+ _constant_op,
+ _data_flow_ops,
+ _functional_ops,
+ _gradients,
+ _histogram_ops,
+ _init_ops,
+ _io_ops,
+ _linalg_ops,
+ _logging_ops,
+ _manip_ops,
+ _math_ops,
+ _numerics,
+ _parsing_ops,
+ _partitioned_variables,
+ _random_ops,
+ _script_ops,
+ _session_ops,
+ _sparse_ops,
+ _special_math_ops,
+ _state_ops,
+ _string_ops,
+ _template,
+ _tensor_array_ops,
+ _variable_scope,
+ _variables,
+])
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py
index ddfd6be6da..bebf1d5e0d 100644
--- a/tensorflow/python/saved_model/loader_impl.py
+++ b/tensorflow/python/saved_model/loader_impl.py
@@ -235,13 +235,10 @@ def load(sess, tags, export_dir, **saver_kwargs):
asset_tensors_dictionary = _get_asset_tensors(export_dir,
meta_graph_def_to_load)
- main_op_tensor = _get_main_op_tensor(meta_graph_def_to_load)
+ main_op_tensor = (
+ _get_main_op_tensor(meta_graph_def_to_load) or
+ (_get_legacy_init_op_tensor(meta_graph_def_to_load)))
if main_op_tensor is not None:
sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
- else:
- legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load)
- if legacy_init_op_tensor is not None:
- sess.run(
- fetches=[legacy_init_op_tensor], feed_dict=asset_tensors_dictionary)
return meta_graph_def_to_load
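The rewrite leans on Python's `or` short-circuit: both helpers return either a tensor or `None`, so the legacy init op is only consulted when no main op is present. A toy illustration with hypothetical stand-ins:

```python
def get_main_op(meta):           # hypothetical stand-in for the lookup
  return meta.get("main_op")

def get_legacy_init_op(meta):    # hypothetical stand-in for the lookup
  return meta.get("legacy_init_op")

meta = {"legacy_init_op": "legacy/init"}
main_op = get_main_op(meta) or get_legacy_init_op(meta)
assert main_op == "legacy/init"  # falls back exactly like the new code
```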
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index 0ddf09260b..affa97062a 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -72,7 +72,8 @@ def freeze_graph_with_def_protos(input_graph_def,
variable_names_blacklist="",
input_meta_graph_def=None,
input_saved_model_dir=None,
- saved_model_tags=None):
+ saved_model_tags=None,
+ checkpoint_version=saver_pb2.SaverDef.V2):
"""Converts all variables in a graph and checkpoint into constants."""
del restore_op_name, filename_tensor_name # Unused by updated loading code.
@@ -100,7 +101,8 @@ def freeze_graph_with_def_protos(input_graph_def,
_ = importer.import_graph_def(input_graph_def, name="")
with session.Session() as sess:
if input_saver_def:
- saver = saver_lib.Saver(saver_def=input_saver_def)
+ saver = saver_lib.Saver(
+ saver_def=input_saver_def, write_version=checkpoint_version)
saver.restore(sess, input_checkpoint)
elif input_meta_graph_def:
restorer = saver_lib.import_meta_graph(
@@ -124,7 +126,8 @@ def freeze_graph_with_def_protos(input_graph_def,
# 'global_step' or a similar housekeeping element) so skip it.
continue
var_list[key] = tensor
- saver = saver_lib.Saver(var_list=var_list)
+ saver = saver_lib.Saver(
+ var_list=var_list, write_version=checkpoint_version)
saver.restore(sess, input_checkpoint)
if initializer_nodes:
sess.run(initializer_nodes.split(","))
@@ -217,7 +220,8 @@ def freeze_graph(input_graph,
variable_names_blacklist="",
input_meta_graph=None,
input_saved_model_dir=None,
- saved_model_tags=tag_constants.SERVING):
+ saved_model_tags=tag_constants.SERVING,
+ checkpoint_version=saver_pb2.SaverDef.V2):
"""Converts all variables in a graph and checkpoint into constants."""
input_graph_def = None
if input_saved_model_dir:
@@ -233,10 +237,21 @@ def freeze_graph(input_graph,
if input_saver:
input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
freeze_graph_with_def_protos(
- input_graph_def, input_saver_def, input_checkpoint, output_node_names,
- restore_op_name, filename_tensor_name, output_graph, clear_devices,
- initializer_nodes, variable_names_whitelist, variable_names_blacklist,
- input_meta_graph_def, input_saved_model_dir, saved_model_tags.split(","))
+ input_graph_def,
+ input_saver_def,
+ input_checkpoint,
+ output_node_names,
+ restore_op_name,
+ filename_tensor_name,
+ output_graph,
+ clear_devices,
+ initializer_nodes,
+ variable_names_whitelist,
+ variable_names_blacklist,
+ input_meta_graph_def,
+ input_saved_model_dir,
+ saved_model_tags.split(","),
+ checkpoint_version=checkpoint_version)
def main(unused_args):
@@ -246,7 +261,7 @@ def main(unused_args):
FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes,
FLAGS.variable_names_whitelist, FLAGS.variable_names_blacklist,
FLAGS.input_meta_graph, FLAGS.input_saved_model_dir,
- FLAGS.saved_model_tags)
+ FLAGS.saved_model_tags, FLAGS.checkpoint_version)
if __name__ == "__main__":
@@ -268,6 +283,11 @@ if __name__ == "__main__":
default="",
help="TensorFlow variables file to load.")
parser.add_argument(
+ "--checkpoint_version",
+ type=int,
+ default=saver_pb2.SaverDef.V2,
+ help="Tensorflow variable file format")
+ parser.add_argument(
"--output_graph",
type=str,
default="",
diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py
index feeed7102c..91f0061ebc 100644
--- a/tensorflow/python/tools/freeze_graph_test.py
+++ b/tensorflow/python/tools/freeze_graph_test.py
@@ -84,9 +84,19 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
input_meta_graph = checkpoint_meta_graph_file
freeze_graph.freeze_graph(
- input_graph_path, input_saver_def_path, input_binary, checkpoint_path,
- output_node_names, restore_op_name, filename_tensor_name,
- output_graph_path, clear_devices, "", "", input_meta_graph)
+ input_graph_path,
+ input_saver_def_path,
+ input_binary,
+ checkpoint_path,
+ output_node_names,
+ restore_op_name,
+ filename_tensor_name,
+ output_graph_path,
+ clear_devices,
+ "",
+ "",
+ input_meta_graph,
+ checkpoint_version=saver_write_version)
# Now we make sure the variable is now a constant, and that the graph still
# produces the expected result.
diff --git a/tensorflow/python/tools/optimize_for_inference_lib.py b/tensorflow/python/tools/optimize_for_inference_lib.py
index c2687bf557..9c19271222 100644
--- a/tensorflow/python/tools/optimize_for_inference_lib.py
+++ b/tensorflow/python/tools/optimize_for_inference_lib.py
@@ -349,6 +349,7 @@ def fold_batch_norms(input_graph_def):
bias_add_op.op = "BiasAdd"
bias_add_op.name = node.name
bias_add_op.attr["T"].CopyFrom(conv_op.attr["T"])
+ bias_add_op.attr["data_format"].CopyFrom(conv_op.attr["data_format"])
bias_add_op.input.extend([new_conv_op.name, offset_op.name])
new_ops.extend([scaled_weights_op, new_conv_op, offset_op, bias_add_op])
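Copying `data_format` matters because `BiasAdd` defaults to NHWC, adding the bias along the last axis; a convolution built with NCHW needs it added along axis 1 instead. Sketched directly against the public API:

```python
import numpy as np
import tensorflow as tf

x = tf.constant(np.zeros([1, 2, 1, 3], np.float32))  # NCHW, 2 channels
bias = tf.constant([1.0, 2.0])

out = tf.nn.bias_add(x, bias, data_format="NCHW")  # adds along axis 1
# With the default NHWC the same call would need a length-3 bias and
# would add it along the last axis instead.

with tf.Session() as sess:
  result = sess.run(out)
assert result[0, 0].max() == 1.0 and result[0, 1].min() == 2.0
```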
diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py
index 7686bb0f14..084a4500f8 100644
--- a/tensorflow/python/tools/optimize_for_inference_test.py
+++ b/tensorflow/python/tools/optimize_for_inference_test.py
@@ -173,48 +173,56 @@ class OptimizeForInferenceTest(test.TestCase):
self.assertNotEqual("BatchNormWithGlobalNormalization", node.op)
def testFoldFusedBatchNorms(self):
- with self.test_session() as sess:
- inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
- input_op = constant_op.constant(
- np.array(inputs), shape=[1, 1, 6, 2], dtype=dtypes.float32)
- weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
- weights_op = constant_op.constant(
- np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
- conv_op = nn_ops.conv2d(
- input_op, weights_op, [1, 1, 1, 1], padding="SAME", name="conv_op")
- mean_op = constant_op.constant(
- np.array([10, 20]), shape=[2], dtype=dtypes.float32)
- variance_op = constant_op.constant(
- np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
- beta_op = constant_op.constant(
- np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
- gamma_op = constant_op.constant(
- np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
- ops.get_default_graph().graph_def_versions.producer = 9
- gen_nn_ops._fused_batch_norm(
- conv_op,
- gamma_op,
- beta_op,
- mean_op,
- variance_op,
- 0.00001,
- is_training=False,
- name="output")
- original_graph_def = sess.graph_def
- original_result = sess.run(["output:0"])
- optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
- original_graph_def)
-
- with self.test_session() as sess:
- _ = importer.import_graph_def(
- optimized_graph_def, input_map={}, name="optimized")
- optimized_result = sess.run(["optimized/output:0"])
-
- self.assertAllClose(
- original_result, optimized_result, rtol=1e-04, atol=1e-06)
-
- for node in optimized_graph_def.node:
- self.assertNotEqual("FusedBatchNorm", node.op)
+ for data_format, use_gpu in [("NHWC", False), ("NCHW", True)]:
+ with self.test_session(use_gpu=use_gpu) as sess:
+ inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
+ input_op = constant_op.constant(
+ np.array(inputs),
+ shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6],
+ dtype=dtypes.float32)
+ weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
+ weights_op = constant_op.constant(
+ np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32)
+ conv_op = nn_ops.conv2d(
+ input_op,
+ weights_op, [1, 1, 1, 1],
+ padding="SAME",
+ data_format=data_format,
+ name="conv_op")
+ mean_op = constant_op.constant(
+ np.array([10, 20]), shape=[2], dtype=dtypes.float32)
+ variance_op = constant_op.constant(
+ np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32)
+ beta_op = constant_op.constant(
+ np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32)
+ gamma_op = constant_op.constant(
+ np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32)
+ ops.get_default_graph().graph_def_versions.producer = 9
+ gen_nn_ops._fused_batch_norm(
+ conv_op,
+ gamma_op,
+ beta_op,
+ mean_op,
+ variance_op,
+ 0.00001,
+ is_training=False,
+ data_format=data_format,
+ name="output")
+ original_graph_def = sess.graph_def
+ original_result = sess.run(["output:0"])
+ optimized_graph_def = optimize_for_inference_lib.fold_batch_norms(
+ original_graph_def)
+
+ with self.test_session(use_gpu=use_gpu) as sess:
+ _ = importer.import_graph_def(
+ optimized_graph_def, input_map={}, name="optimized")
+ optimized_result = sess.run(["optimized/output:0"])
+
+ self.assertAllClose(
+ original_result, optimized_result, rtol=1e-04, atol=1e-06)
+
+ for node in optimized_graph_def.node:
+ self.assertNotEqual("FusedBatchNorm", node.op)
def testFuseResizePadAndConv(self):
with self.test_session() as sess:
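The folding under test rewrites `gamma * (conv(x, w) - mean) / sqrt(var + eps) + beta` into a single convolution with rescaled weights plus a `BiasAdd`; by linearity of convolution, scaling the output per channel equals convolving with scaled weights. The per-channel algebra, checked in NumPy with the test's constants:

```python
import numpy as np

gamma = np.array([1.0, 2.0])
beta = np.array([0.1, 0.6])
mean = np.array([10.0, 20.0])
var = np.array([0.25, 0.5])
eps = 0.00001

conv_out = np.random.randn(8, 2)        # any conv output, channels last
scale = gamma / np.sqrt(var + eps)

batch_norm = (conv_out - mean) * scale + beta
folded = conv_out * scale + (beta - mean * scale)  # the BiasAdd offset

assert np.allclose(batch_norm, folded)
```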
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 667a4b1db8..33f6debbcb 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -31,6 +31,7 @@ import warnings
import numpy as np
+from six import integer_types
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.core.example import example_pb2
@@ -440,7 +441,7 @@ def _create_example_string(example_dict):
elif isinstance(feature_list[0], str):
example.features.feature[feature_name].bytes_list.value.extend(
feature_list)
- elif isinstance(feature_list[0], (int, long)):
+ elif isinstance(feature_list[0], integer_types):
example.features.feature[feature_name].int64_list.value.extend(
feature_list)
else:
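`six.integer_types` is `(int, long)` on Python 2 and `(int,)` on Python 3, so the check keeps its old behavior without referencing the `long` builtin that Python 3 removed:

```python
from six import integer_types

assert isinstance(3, integer_types)        # true on both Python 2 and 3
assert not isinstance(3.0, integer_types)  # floats are still rejected
```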
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index 17e07e171a..aae757b99a 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -336,7 +336,7 @@ class CheckpointSaverListener(object):
`CheckpointSaverHook`, as in this example:
```python
- class ExampleCheckpointSaverListerner(CheckpointSaverListener):
+ class ExampleCheckpointSaverListener(CheckpointSaverListener):
def begin(self):
# You can add ops to the graph here.
print('Starting the session.')
@@ -352,7 +352,7 @@ class CheckpointSaverListener(object):
print('Done with the session.')
...
- listener = ExampleCheckpointSaverListerner()
+ listener = ExampleCheckpointSaverListener()
saver_hook = tf.train.CheckpointSaverHook(
checkpoint_dir, listeners=[listener])
with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]):
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 992184ec9e..bd9985a7c5 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -58,6 +58,8 @@ _restore_sparse = sparse_ops._take_many_sparse_from_tensors_map
def match_filenames_once(pattern, name=None):
"""Save the list of files matching pattern, so it is only computed once.
+ NOTE: The order of the files returned can be non-deterministic.
+
Args:
pattern: A file pattern (glob), or 1D tensor of file patterns.
name: A name for the operations (optional).
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 3888e9bba4..0c1c8e664b 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -1597,9 +1597,9 @@ class Saver(object):
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
- A string: path prefix used for the checkpoint files. If the saver is
- sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
- is the number of shards created.
+ A string: path prefix used for the checkpoint files. If checkpoint
+ format is V1 and the saver is sharded, this string ends with:
+ '-?????-of-nnnnn' where 'nnnnn' is the number of shards created.
If the saver is empty, returns None.
Raises:
@@ -1749,6 +1749,12 @@ class Saver(object):
return
if save_path is None:
raise ValueError("Can't load save_path when it is None.")
+ if (os.path.isfile(save_path) and
+ self._write_version not in (
+ saver_pb2.SaverDef.V1, saver_pb2.SaverDef.LEGACY)):
+ raise ValueError("The specified path: %s is a file."
+ " Please specify only the path prefix"
+ " to the checkpoint files." % save_path)
logging.info("Restoring parameters from %s", save_path)
if context.in_graph_mode():
sess.run(self.saver_def.restore_op_name,
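The new guard targets a common mistake with V2 checkpoints: `Saver.save` returns a path prefix, and the actual files on disk carry `.index` and `.data-*` suffixes, so passing one of those files to `restore` now fails fast. A sketch with an illustrative path:

```python
import tensorflow as tf

v = tf.Variable(1.0)
saver = tf.train.Saver()  # defaults to the V2 checkpoint format
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  prefix = saver.save(sess, "/tmp/model.ckpt")  # returns the path prefix
  saver.restore(sess, prefix)                   # correct: pass the prefix
  # saver.restore(sess, prefix + ".index")      # a real file: now a ValueError
```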
diff --git a/tensorflow/python/util/compat_internal.py b/tensorflow/python/util/compat_internal.py
new file mode 100644
index 0000000000..fee1d6fab7
--- /dev/null
+++ b/tensorflow/python/util/compat_internal.py
@@ -0,0 +1,34 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions for Python 2 vs. 3 compatibility that are private to TensorFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.util.compat import as_str_any
+
+
+def path_to_str(path):
+ """Returns the file system path representation of a `PathLike` object,
+ else as it is.
+
+ Args:
+ path: An object that can be converted to path representation.
+
+ Returns:
+ A `str` object.
+ """
+ if hasattr(path, "__fspath__"):
+ path = as_str_any(path.__fspath__())
+ return path
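A quick sketch of the helper's behavior on Python 3.6+, where `pathlib.Path` implements `__fspath__`; non-path-like arguments pass through unchanged:

```python
import pathlib
from tensorflow.python.util.compat_internal import path_to_str

assert path_to_str(pathlib.Path("/tmp/model")) == "/tmp/model"
assert path_to_str("/tmp/model") == "/tmp/model"  # plain strings untouched
assert path_to_str(42) == 42                      # non-path objects untouched
```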