diff options
Diffstat (limited to 'tensorflow/python')
50 files changed, 832 insertions, 228 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 2b4d5b8e0f..f5cd7885e7 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -76,6 +76,7 @@ py_library( ":layers", ":lib", ":list_ops", + ":manip_ops", ":math_ops", ":metrics", ":nn", @@ -1424,6 +1425,14 @@ tf_gen_op_wrapper_private_py( ) tf_gen_op_wrapper_private_py( + name = "manip_ops_gen", + visibility = [ + "//learning/brain/python/ops:__pkg__", + "//tensorflow/python/kernel_tests:__pkg__", + ], +) + +tf_gen_op_wrapper_private_py( name = "math_ops_gen", visibility = [ "//learning/brain/google/python/ops:__pkg__", @@ -1755,6 +1764,8 @@ py_library( ":linalg_grad", ":linalg_ops", ":logging_ops", + ":manip_grad", + ":manip_ops", ":math_grad", ":math_ops", ":platform", @@ -1878,6 +1889,29 @@ py_library( ) py_library( + name = "manip_grad", + srcs = ["ops/manip_grad.py"], + srcs_version = "PY2AND3", + deps = [ + ":control_flow_ops", + ":framework_for_generated_wrappers", + ":manip_ops", + ], +) + +py_library( + name = "manip_ops", + srcs = ["ops/manip_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":dtypes", + ":framework_ops", + ":manip_ops_gen", + "//third_party/py/numpy", + ], +) + +py_library( name = "logging_ops", srcs = ["ops/logging_ops.py"], srcs_version = "PY2AND3", @@ -2339,6 +2373,8 @@ py_library( ":linalg_ops", ":logging_ops", ":lookup_ops", + ":manip_grad", + ":manip_ops", ":math_grad", ":math_ops", ":numerics", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index bc9ddec2a5..ea7604d30f 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -84,6 +84,7 @@ from tensorflow.python.feature_column import feature_column_lib as feature_colum from tensorflow.python.layers import layers from tensorflow.python.ops import bitwise_ops as bitwise from tensorflow.python.ops import image_ops as image +from tensorflow.python.ops import manip_ops as manip from tensorflow.python.ops import metrics from tensorflow.python.ops 
import nn from tensorflow.python.ops import sets @@ -241,6 +242,7 @@ _allowed_symbols.extend([ 'linalg', 'logging', 'losses', + 'manip', 'metrics', 'newaxis', 'nn', diff --git a/tensorflow/python/client/session_benchmark.py b/tensorflow/python/client/session_benchmark.py index 721bca91b7..da74855193 100644 --- a/tensorflow/python/client/session_benchmark.py +++ b/tensorflow/python/client/session_benchmark.py @@ -22,6 +22,7 @@ import time import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.client import session from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 5d318531d5..c4b7e4919b 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -203,7 +203,7 @@ class Dataset(object): tensors: A nested structure of tensors. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return TensorDataset(tensors) @@ -216,7 +216,7 @@ class Dataset(object): 0th dimension. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return TensorSliceDataset(tensors) @@ -229,7 +229,7 @@ class Dataset(object): sparse_tensor: A `tf.SparseTensor`. Returns: - A `Dataset` of rank-(N-1) sparse tensors. + Dataset: A `Dataset` of rank-(N-1) sparse tensors. """ return SparseTensorSliceDataset(sparse_tensor) @@ -315,7 +315,7 @@ class Dataset(object): `generator`. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ if not callable(generator): raise TypeError("`generator` must be callable.") @@ -458,7 +458,7 @@ class Dataset(object): len(args) == 3 -> start = args[0], stop = args[1, stop = args[2] Returns: - A `RangeDataset`. + Dataset: A `RangeDataset`. Raises: ValueError: if len(args) == 0. @@ -502,7 +502,7 @@ class Dataset(object): datasets: A nested structure of datasets. Returns: - A `Dataset`. + Dataset: A `Dataset`. 
""" return ZipDataset(datasets) @@ -528,7 +528,7 @@ class Dataset(object): dataset: `Dataset` to be concatenated. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return ConcatenateDataset(self, dataset) @@ -540,7 +540,7 @@ class Dataset(object): maximum number elements that will be buffered when prefetching. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return PrefetchDataset(self, buffer_size) @@ -558,12 +558,14 @@ class Dataset(object): - /path/to/dir/b.py - /path/to/dir/c.py + NOTE: The order of the file names returned can be non-deterministic. + Args: file_pattern: A string or scalar string `tf.Tensor`, representing the filename pattern that will be matched. Returns: - A `Dataset` of strings corresponding to file names. + Dataset: A `Dataset` of strings corresponding to file names. """ return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern)) @@ -580,7 +582,7 @@ class Dataset(object): indefinitely. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return RepeatDataset(self, count) @@ -604,7 +606,7 @@ class Dataset(object): iterated over. (Defaults to `True`.) Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return ShuffleDataset(self, buffer_size, seed, reshuffle_each_iteration) @@ -617,7 +619,7 @@ class Dataset(object): If a filename is not provided, the dataset will be cached in memory. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return CacheDataset(self, filename) @@ -631,7 +633,7 @@ class Dataset(object): dataset, the new dataset will contain all elements of this dataset. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return TakeDataset(self, count) @@ -646,7 +648,7 @@ class Dataset(object): is -1, skips the entire dataset. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return SkipDataset(self, count) @@ -693,7 +695,7 @@ class Dataset(object): index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. Returns: - A `Dataset`. + Dataset: A `Dataset`. 
Raises: ValueError: if `num_shards` or `index` are illegal values. Note: error @@ -737,7 +739,7 @@ class Dataset(object): consecutive elements of this dataset to combine in a single batch. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return BatchDataset(self, batch_size) @@ -766,7 +768,7 @@ class Dataset(object): the empty string for string types. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values) @@ -782,7 +784,7 @@ class Dataset(object): specified, elements will be processed sequentially. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ if num_parallel_calls is None: return MapDataset(self, map_func) @@ -798,7 +800,7 @@ class Dataset(object): `Dataset`. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return FlatMapDataset(self, map_func) @@ -867,7 +869,7 @@ class Dataset(object): input element before cycling to another input element. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return InterleaveDataset(self, map_func, cycle_length, block_length) @@ -880,7 +882,7 @@ class Dataset(object): scalar `tf.bool` tensor. Returns: - A `Dataset`. + Dataset: A `Dataset`. """ return FilterDataset(self, predicate) @@ -901,10 +903,11 @@ class Dataset(object): Args: transformation_func: A function that takes one `Dataset` argument and - returns a `Dataset`. + returns a `Dataset`. Returns: - The `Dataset` returned by applying `transformation_func` to this dataset. + Dataset: The `Dataset` returned by applying `transformation_func` to this + dataset. """ dataset = transformation_func(self) if not isinstance(dataset, Dataset): diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py index e387e35740..e90ce3fb40 100644 --- a/tensorflow/python/data/util/nest.py +++ b/tensorflow/python/data/util/nest.py @@ -266,7 +266,7 @@ def map_structure(func, *structure, **check_types_dict): and the return value will contain the results in the same structure. 
Args: - func: A callable that acceps as many arguments are there are structures. + func: A callable that accepts as many arguments are there are structures. *structure: scalar, or tuple or list of constructed scalars and/or other tuples/lists, or scalars. Note: numpy arrays are considered scalars. **check_types_dict: only valid keyword argument is `check_types`. If set to @@ -479,8 +479,8 @@ def map_structure_up_to(shallow_tree, func, *inputs): The `inputs`, can be thought of as having the same structure as `shallow_tree`, but with leaf nodes that are themselves tree structures. - This function therefore will return something with the same base structure as - `shallow_tree`. + This function, therefore, will return something with the same base structure + as `shallow_tree`. Examples: diff --git a/tensorflow/python/data/util/sparse.py b/tensorflow/python/data/util/sparse.py index 5ebcb4ea81..5e6d224709 100644 --- a/tensorflow/python/data/util/sparse.py +++ b/tensorflow/python/data/util/sparse.py @@ -141,7 +141,7 @@ def serialize_sparse_tensors(tensors): tensors: a tensor structure to serialize. Returns: - `tensors` with any sparse tensors replaced by the their serialized version. + `tensors` with any sparse tensors replaced by their serialized version. 
""" ret = nest.pack_sequence_as(tensors, [ diff --git a/tensorflow/python/debug/cli/tensor_format.py b/tensorflow/python/debug/cli/tensor_format.py index d4aea76d65..e0759a8bc1 100644 --- a/tensorflow/python/debug/cli/tensor_format.py +++ b/tensorflow/python/debug/cli/tensor_format.py @@ -535,7 +535,7 @@ def numeric_summary(tensor): if not isinstance(tensor, np.ndarray) or not np.size(tensor): return debugger_cli_common.RichTextLines([ "No numeric summary available due to empty tensor."]) - elif (np.issubdtype(tensor.dtype, np.float) or + elif (np.issubdtype(tensor.dtype, np.floating) or np.issubdtype(tensor.dtype, np.complex) or np.issubdtype(tensor.dtype, np.integer)): counts = [ diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py index c4b13a1045..8d355aa27f 100644 --- a/tensorflow/python/debug/lib/debug_data.py +++ b/tensorflow/python/debug/lib/debug_data.py @@ -222,7 +222,7 @@ def has_inf_or_nan(datum, tensor): # Also return False for data types that cannot be represented as numpy # arrays. 
return False - elif (np.issubdtype(tensor.dtype, np.float) or + elif (np.issubdtype(tensor.dtype, np.floating) or np.issubdtype(tensor.dtype, np.complex) or np.issubdtype(tensor.dtype, np.integer)): return np.any(np.isnan(tensor)) or np.any(np.isinf(tensor)) diff --git a/tensorflow/python/eager/execution_callbacks.py b/tensorflow/python/eager/execution_callbacks.py index 2f1654dda4..988442c971 100644 --- a/tensorflow/python/eager/execution_callbacks.py +++ b/tensorflow/python/eager/execution_callbacks.py @@ -153,7 +153,7 @@ def inf_nan_callback(op_type, continue numpy_dtype = output.dtype.as_numpy_dtype - if (np.issubdtype(numpy_dtype, np.float) or + if (np.issubdtype(numpy_dtype, np.floating) or np.issubdtype(numpy_dtype, np.complex) or np.issubdtype(numpy_dtype, np.integer)): try: diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py index 2bdec69303..706575985f 100644 --- a/tensorflow/python/estimator/canned/dnn_testing_utils.py +++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py @@ -877,7 +877,7 @@ class BaseDNNWarmStartingTest(object): # Create a second DNNClassifier, warm-started from the first. Use a # learning_rate = 0.0 optimizer to check values (use SGD so we don't have - # accumulator values that change). Use a a new FeatureColumn with a + # accumulator values that change). Use a new FeatureColumn with a # different vocabulary for occupation. new_vocab_list = ['doctor', 'consultant', 'engineer'] new_vocab_file = os.path.join(self._ckpt_and_vocab_dir, diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py index cccb9af4b2..3e9183cf1b 100644 --- a/tensorflow/python/estimator/canned/linear_testing_utils.py +++ b/tensorflow/python/estimator/canned/linear_testing_utils.py @@ -2003,7 +2003,7 @@ class BaseLinearWarmStartingTest(object): # Create a second LinearClassifier, warm-started from the first. 
Use a # learning_rate = 0.0 optimizer to check values (use SGD so we don't have - # accumulator values that change). Use a a new FeatureColumn with a + # accumulator values that change). Use a new FeatureColumn with a # different vocabulary for occupation. new_vocab_list = ['doctor', 'consultant', 'engineer'] new_vocab_file = os.path.join(self._ckpt_and_vocab_dir, diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 6da890cd22..17fab3df4d 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -55,6 +55,7 @@ from tensorflow.python.training import saver from tensorflow.python.training import training from tensorflow.python.training import training_util from tensorflow.python.util import compat +from tensorflow.python.util import compat_internal from tensorflow.python.util import nest @@ -179,7 +180,7 @@ class Estimator(object): self._config = config # Model directory. - model_dir = compat.path_to_str(model_dir) + model_dir = compat_internal.path_to_str(model_dir) if (model_dir is not None) and (self._config.model_dir is not None): if model_dir != self._config.model_dir: # TODO(alanyee): remove this suppression after it is no longer needed diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py index 61a537022b..0c636a8da1 100644 --- a/tensorflow/python/estimator/run_config.py +++ b/tensorflow/python/estimator/run_config.py @@ -27,7 +27,7 @@ import six from tensorflow.core.protobuf import config_pb2 from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib -from tensorflow.python.util import compat +from tensorflow.python.util import compat_internal _USE_DEFAULT = object() @@ -444,7 +444,8 @@ class RunConfig(object): if tf_config: logging.info('TF_CONFIG environment variable: %s', tf_config) - model_dir = _get_model_dir(tf_config, compat.path_to_str(model_dir)) + model_dir = 
_get_model_dir(tf_config, + compat_internal.path_to_str(model_dir)) RunConfig._replace( self, diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional.py b/tensorflow/python/keras/_impl/keras/layers/convolutional.py index b2ad4c4b65..2ee0732775 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional.py @@ -563,7 +563,7 @@ class Conv2DTranspose(tf_convolutional_layers.Conv2DTranspose, Layer): return dict(list(base_config.items()) + list(config.items())) -class Conv3DTranspose(tf_convolutional_layers.Conv3D, Layer): +class Conv3DTranspose(tf_convolutional_layers.Conv3DTranspose, Layer): """Transposed convolution layer (sometimes called Deconvolution). The need for transposed convolutions generally arises diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index c87b7652ad..3a6058054b 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -1602,6 +1602,19 @@ cuda_py_test( ) cuda_py_test( + name = "manip_ops_test", + size = "small", + srcs = ["manip_ops_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:manip_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + ], + tags = ["no_windows_gpu"], +) + +cuda_py_test( name = "matmul_op_test", size = "small", srcs = ["matmul_op_test.py"], diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index aae6d0a36e..7ec4624310 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -1162,6 +1162,27 @@ class InvertPermutationTest(test_util.TensorFlowTestCase): self.assertAllEqual(y.eval(), [2, 4, 3, 0, 1]) +class UnravelIndexTest(test_util.TensorFlowTestCase): + + def testUnravelIndex(self): + with self.test_session(): + for dtype in [dtypes.int32, 
dtypes.int64]: + indices_1 = constant_op.constant(1621, dtype=dtype) + dims_1 = constant_op.constant([6, 7, 8, 9], dtype=dtype) + out_1 = array_ops.unravel_index(indices_1, dims_1) + self.assertAllEqual(out_1.eval(), [3, 1, 4, 1]) + + indices_2 = constant_op.constant([1621], dtype=dtype) + dims_2 = constant_op.constant([6, 7, 8, 9], dtype=dtype) + out_2 = array_ops.unravel_index(indices_2, dims_2) + self.assertAllEqual(out_2.eval(), [[3], [1], [4], [1]]) + + indices_3 = constant_op.constant([22, 41, 37], dtype=dtype) + dims_3 = constant_op.constant([7, 6], dtype=dtype) + out_3 = array_ops.unravel_index(indices_3, dims_3) + self.assertAllEqual(out_3.eval(), [[3, 6, 6], [4, 5, 1]]) + + class GuaranteeConstOpTest(test_util.TensorFlowTestCase): def testSimple(self): diff --git a/tensorflow/python/kernel_tests/constant_op_test.py b/tensorflow/python/kernel_tests/constant_op_test.py index 030c690167..16e56349c4 100644 --- a/tensorflow/python/kernel_tests/constant_op_test.py +++ b/tensorflow/python/kernel_tests/constant_op_test.py @@ -454,18 +454,19 @@ class ZerosLikeTest(test.TestCase): def testZerosLikeCPU(self): for dtype in [ - dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int8, - dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.uint16, dtypes_lib.int32, - dtypes_lib.int64, dtypes_lib.bool, dtypes_lib.complex64, - dtypes_lib.complex128, dtypes_lib.string + dtypes_lib.half, dtypes_lib.float32, dtypes_lib.float64, + dtypes_lib.int8, dtypes_lib.uint8, dtypes_lib.int16, dtypes_lib.uint16, + dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.bool, + dtypes_lib.complex64, dtypes_lib.complex128, dtypes_lib.string ]: self._compareZeros(dtype, fully_defined_shape=False, use_gpu=False) self._compareZeros(dtype, fully_defined_shape=True, use_gpu=False) def testZerosLikeGPU(self): for dtype in [ - dtypes_lib.float32, dtypes_lib.float64, dtypes_lib.int32, - dtypes_lib.bool, dtypes_lib.int64, dtypes_lib.string + dtypes_lib.half, dtypes_lib.float32, dtypes_lib.float64, + 
dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.complex64, + dtypes_lib.complex128, dtypes_lib.bool ]: self._compareZeros(dtype, fully_defined_shape=False, use_gpu=True) self._compareZeros(dtype, fully_defined_shape=True, use_gpu=True) diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index 3e9bd3dade..edfb20d6a2 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -24,6 +24,7 @@ import time import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib import layers from tensorflow.python.client import session as session_lib from tensorflow.python.framework import constant_op @@ -519,7 +520,7 @@ class Conv2DTest(test.TestCase): dilations=[2, 2], padding="VALID") - # TODO this currently fails. + # TODO(yzhwang): this currently fails. # self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1], # filter_in_sizes=[2, 2, 1, 1], # strides=[4, 4], padding="SAME", diff --git a/tensorflow/python/kernel_tests/decode_jpeg_op_test.py b/tensorflow/python/kernel_tests/decode_jpeg_op_test.py index ead55cd03b..89fd26c544 100644 --- a/tensorflow/python/kernel_tests/decode_jpeg_op_test.py +++ b/tensorflow/python/kernel_tests/decode_jpeg_op_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import os import time +from six.moves import xrange from tensorflow.python.client import session from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops diff --git a/tensorflow/python/kernel_tests/io_ops_test.py b/tensorflow/python/kernel_tests/io_ops_test.py index f91875c6f0..61944f7e31 100644 --- a/tensorflow/python/kernel_tests/io_ops_test.py +++ b/tensorflow/python/kernel_tests/io_ops_test.py @@ -1,4 +1,4 @@ -# -*- coding: utf-8 -*- +# -*- coding: utf-8 -*- # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py index 00c6706593..197dbf44af 100644 --- a/tensorflow/python/kernel_tests/losses_test.py +++ b/tensorflow/python/kernel_tests/losses_test.py @@ -953,14 +953,14 @@ class MeanPairwiseSquaredErrorTest(test.TestCase): # Compute the expected loss 'manually'. total = np.zeros((batch_size,)) for b in range(batch_size): - for i in range(dims): - for j in range(dims): + for i in range(dims - 1): + for j in range(i + 1, dims): x = self._predictions[b, i].item() - self._predictions[b, j].item() y = self._labels[b, i].item() - self._labels[b, j].item() diff = (x - y) total[b] += (diff * diff) - self._expected_losses = np.divide(total, 9.0) + self._expected_losses = np.divide(total, 3.0) def testValueErrorThrownWhenWeightIsNone(self): with self.test_session(): @@ -1059,8 +1059,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase): [[4, 8, 12], [1, 2, 3], [4, 5, 6]], [[8, 1, 3], [7, 8, 9], [10, 11, 12]], ]) - self._test_valid_weights( - labels, predictions, expected_loss=122.22222) + self._test_valid_weights(labels, predictions, expected_loss=137.5) def test3dWeightedScalar(self): labels = np.array([ @@ -1073,8 +1072,7 @@ class MeanPairwiseSquaredErrorTest(test.TestCase): ]) weight = 3.0 self._test_valid_weights( - labels, predictions, expected_loss=weight * 122.22222, - weights=weight) + labels, predictions, expected_loss=weight * 137.5, weights=weight) def _test_invalid_weights( self, labels, predictions, weights=1.0): @@ -1124,7 +1122,9 @@ class MeanPairwiseSquaredErrorTest(test.TestCase): ]) self._test_valid_weights( # TODO(ptucker): This doesn't look right. 
- labels, predictions, expected_loss=9 * 122.22222, + labels, + predictions, + expected_loss=9 * 137.5, weights=np.ones((2, 3, 3))) def testLossWithAllZeroBatchSpecificWeights(self): diff --git a/tensorflow/python/kernel_tests/manip_ops_test.py b/tensorflow/python/kernel_tests/manip_ops_test.py new file mode 100644 index 0000000000..b8200ac0cb --- /dev/null +++ b/tensorflow/python/kernel_tests/manip_ops_test.py @@ -0,0 +1,138 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for manip_ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import test_util +from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import manip_ops +from tensorflow.python.platform import test as test_lib + +# pylint: disable=g-import-not-at-top +try: + from distutils.version import StrictVersion as Version + # numpy.roll for multiple shifts was introduced in numpy version 1.12.0 + NP_ROLL_CAN_MULTISHIFT = Version(np.version.version) >= Version("1.12.0") +except ImportError: + NP_ROLL_CAN_MULTISHIFT = False +# pylint: enable=g-import-not-at-top + + +class RollTest(test_util.TensorFlowTestCase): + + def _testRoll(self, np_input, shift, axis): + expected_roll = np.roll(np_input, shift, axis) + with self.test_session(): + roll = manip_ops.roll(np_input, shift, axis) + self.assertAllEqual(roll.eval(), expected_roll) + + def _testGradient(self, np_input, shift, axis): + with self.test_session(): + inx = constant_op.constant(np_input.tolist()) + xs = list(np_input.shape) + y = manip_ops.roll(inx, shift, axis) + # Expected y's shape to be the same + ys = xs + jacob_t, jacob_n = gradient_checker.compute_gradient( + inx, xs, y, ys, x_init_value=np_input) + self.assertAllClose(jacob_t, jacob_n, rtol=1e-5, atol=1e-5) + + def _testAll(self, np_input, shift, axis): + self._testRoll(np_input, shift, axis) + if np_input.dtype == np.float32: + self._testGradient(np_input, shift, axis) + + def testIntTypes(self): + for t in [np.int32, np.int64]: + self._testAll(np.random.randint(-100, 100, (5)).astype(t), 3, 0) + if NP_ROLL_CAN_MULTISHIFT: + self._testAll( + np.random.randint(-100, 100, (4, 4, 3)).astype(t), [1, -2, 3], + [0, 1, 2]) + 
self._testAll( + np.random.randint(-100, 100, (4, 2, 1, 3)).astype(t), [0, 1, -2], + [1, 2, 3]) + + def testFloatTypes(self): + for t in [np.float32, np.float64]: + self._testAll(np.random.rand(5).astype(t), 2, 0) + if NP_ROLL_CAN_MULTISHIFT: + self._testAll(np.random.rand(3, 4).astype(t), [1, 2], [1, 0]) + self._testAll(np.random.rand(1, 3, 4).astype(t), [1, 0, -3], [0, 1, 2]) + + def testComplexTypes(self): + for t in [np.complex64, np.complex128]: + x = np.random.rand(4, 4).astype(t) + self._testAll(x + 1j * x, 2, 0) + if NP_ROLL_CAN_MULTISHIFT: + x = np.random.rand(2, 5).astype(t) + self._testAll(x + 1j * x, [1, 2], [1, 0]) + x = np.random.rand(3, 2, 1, 1).astype(t) + self._testAll(x + 1j * x, [2, 1, 1, 0], [0, 3, 1, 2]) + + def testRollInputMustVectorHigherRaises(self): + tensor = 7 + shift = 1 + axis = 0 + with self.test_session(): + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "input must be 1-D or higher"): + manip_ops.roll(tensor, shift, axis).eval() + + def testRollAxisMustBeScalarOrVectorRaises(self): + tensor = [[1, 2], [3, 4]] + shift = 1 + axis = [[0, 1]] + with self.test_session(): + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "axis must be a scalar or a 1-D vector"): + manip_ops.roll(tensor, shift, axis).eval() + + def testRollShiftMustBeScalarOrVectorRaises(self): + tensor = [[1, 2], [3, 4]] + shift = [[0, 1]] + axis = 1 + with self.test_session(): + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "shift must be a scalar or a 1-D vector"): + manip_ops.roll(tensor, shift, axis).eval() + + def testRollShiftAndAxisMustBeSameSizeRaises(self): + tensor = [[1, 2], [3, 4]] + shift = [1] + axis = [0, 1] + with self.test_session(): + with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "shift and axis must have the same size"): + manip_ops.roll(tensor, shift, axis).eval() + + def testRollAxisOutOfRangeRaises(self): + tensor = [1, 2] + shift = 1 + axis = 1 + with self.test_session(): + 
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError, + "is out of range"): + manip_ops.roll(tensor, shift, axis).eval() + + +if __name__ == "__main__": + test_lib.main() diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py index 0c77d1db92..daa42938e6 100644 --- a/tensorflow/python/kernel_tests/rnn_test.py +++ b/tensorflow/python/kernel_tests/rnn_test.py @@ -23,6 +23,7 @@ import timeit import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib import rnn as contrib_rnn from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session diff --git a/tensorflow/python/kernel_tests/tensordot_op_test.py b/tensorflow/python/kernel_tests/tensordot_op_test.py index f1670a47f5..8ad29afd0a 100644 --- a/tensorflow/python/kernel_tests/tensordot_op_test.py +++ b/tensorflow/python/kernel_tests/tensordot_op_test.py @@ -66,7 +66,7 @@ class TensordotTest(test_lib.TestCase): a = [[1, 2], [3, 4]] b = [[1, 2], [3, 4]] # Invalid static axes. 
- for axes_value in -1, 0, [1], [[1]], [[1], [0, 1]]: + for axes_value in -1, 3, [1], [[1]], [[1], [0, 1]]: with self.assertRaises(ValueError): math_ops.tensordot(a, b, axes_value) @@ -91,7 +91,7 @@ class TensordotTest(test_lib.TestCase): # Test case for 11950 def test_valid_axis(self): - for axes_value in [1, 2], [[1], [2]]: + for axes_value in [1, 2], [[1], [2]], [[], []], 0: with self.test_session() as sess: np_a = np.ones((3, 3)) np_b = np.array([2, 3, 1])[None, None] @@ -105,29 +105,29 @@ class TensordotTest(test_lib.TestCase): self.assertAllEqual(tf_ans, np_ans) def test_partial_shape_inference(self): - a = array_ops.placeholder(dtypes.float32) - b = array_ops.placeholder(dtypes.float32) - axes = ([1], [0]) - output = math_ops.tensordot(a, b, axes) - self.assertEqual(output.get_shape().ndims, None) - a.set_shape([None, 2]) - b.set_shape([2, 3]) - output = math_ops.tensordot(a, b, axes) - output_shape = output.get_shape() - self.assertEqual(output_shape.ndims, 2) - output_shape = output_shape.as_list() - self.assertEqual(output_shape[0], None) - self.assertEqual(output_shape[1], 3) - a = array_ops.placeholder(dtypes.float32) - b = array_ops.placeholder(dtypes.float32) - a.set_shape([2, 2]) - b.set_shape([2, None]) - output = math_ops.tensordot(a, b, axes) - output_shape = output.get_shape() - self.assertEqual(output_shape.ndims, 2) - output_shape = output_shape.as_list() - self.assertEqual(output_shape[0], 2) - self.assertEqual(output_shape[1], None) + for axes in ([1], [0]), 1: + a = array_ops.placeholder(dtypes.float32) + b = array_ops.placeholder(dtypes.float32) + output = math_ops.tensordot(a, b, axes) + self.assertEqual(output.get_shape().ndims, None) + a.set_shape([None, 2]) + b.set_shape([2, 3]) + output = math_ops.tensordot(a, b, axes) + output_shape = output.get_shape() + self.assertEqual(output_shape.ndims, 2) + output_shape = output_shape.as_list() + self.assertEqual(output_shape[0], None) + self.assertEqual(output_shape[1], 3) + a = 
array_ops.placeholder(dtypes.float32) + b = array_ops.placeholder(dtypes.float32) + a.set_shape([2, 2]) + b.set_shape([2, None]) + output = math_ops.tensordot(a, b, axes) + output_shape = output.get_shape() + self.assertEqual(output_shape.ndims, 2) + output_shape = output_shape.as_list() + self.assertEqual(output_shape[0], 2) + self.assertEqual(output_shape[1], None) def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): @@ -196,8 +196,8 @@ def _get_tensordot_tests(dtype_, rank_a_, rank_b_, num_dims_, dynamic_shape_): low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_) b_np = np.random.uniform( low=-1.0, high=1.0, size=np.prod(shape)).reshape(shape).astype(dtype_) - all_axes = [1] - if a_np.ndim > 1: + all_axes = [0, 1] + if a_np.ndim > 2: all_axes.append(a_np.ndim - 1) for axes in all_axes: np_ans = np.tensordot(a_np, b_np, axes=axes) diff --git a/tensorflow/python/kernel_tests/topk_op_test.py b/tensorflow/python/kernel_tests/topk_op_test.py index efb5b9f364..6ab931fdb9 100644 --- a/tensorflow/python/kernel_tests/topk_op_test.py +++ b/tensorflow/python/kernel_tests/topk_op_test.py @@ -58,7 +58,7 @@ class TopKTest(test.TestCase): # Do some special casing of equality of indices: if indices # are not the same, but values are floating type, ensure that # the values are within epsilon of each other. - if not np.issubdtype(np_expected_values.dtype, np.float): + if not np.issubdtype(np_expected_values.dtype, np.floating): # Values are not floating point type; check indices exactly self.assertAllEqual(np_expected_indices, indices) else: diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index 79c421f4c9..e8dba3cea3 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -1094,7 +1094,7 @@ class SeparableConv1D(_SeparableConv): strides = (1, 1, 1) + self.strides spatial_start_dim = 2 - # Explictly broadcast inputs and kernels to 4D. 
+ # Explicitly broadcast inputs and kernels to 4D. # TODO(fchollet): refactor when a native separable_conv1d op is available. inputs = array_ops.expand_dims(inputs, spatial_start_dim) depthwise_kernel = array_ops.expand_dims(self.depthwise_kernel, 0) @@ -1904,6 +1904,7 @@ class Conv3DTranspose(Conv3D): dtype=self.dtype) else: self.bias = None + self.built = True def call(self, inputs): inputs_shape = array_ops.shape(inputs) @@ -1974,6 +1975,8 @@ class Conv3DTranspose(Conv3D): if self.use_bias: outputs_shape = outputs.shape.as_list() + if outputs_shape[0] is None: + outputs_shape[0] = -1 if self.data_format == 'channels_first': outputs_4d = array_ops.reshape(outputs, [ outputs_shape[0], outputs_shape[1], @@ -2007,11 +2010,11 @@ class Conv3DTranspose(Conv3D): output_shape[c_axis] = self.filters output_shape[d_axis] = utils.deconv_output_length( - output_shape[d_axis], stride_d, kernel_d, self.padding) + output_shape[d_axis], kernel_d, self.padding, stride_d) output_shape[h_axis] = utils.deconv_output_length( - output_shape[h_axis], stride_h, kernel_h, self.padding) + output_shape[h_axis], kernel_h, self.padding, stride_h) output_shape[w_axis] = utils.deconv_output_length( - output_shape[w_axis], stride_w, kernel_w, self.padding) + output_shape[w_axis], kernel_w, self.padding, stride_w) return tensor_shape.TensorShape(output_shape) diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py index e8be347799..7407d9a7b3 100644 --- a/tensorflow/python/layers/utils.py +++ b/tensorflow/python/layers/utils.py @@ -81,7 +81,7 @@ def normalize_tuple(value, n, name): for single_value in value_tuple: try: int(single_value) - except ValueError: + except (ValueError, TypeError): raise ValueError('The `' + name + '` argument must be a tuple of ' + str(n) + ' integers. 
Received: ' + str(value) + ' ' 'including element ' + str(single_value) + ' of type' + diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index e3902f5a8a..ad409ad7e5 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -35,6 +35,7 @@ See the @{$python/array_ops} guide. @@reshape @@squeeze @@expand_dims +@@unravel_index @@meshgrid @@slice @@strided_slice @@ -1589,9 +1590,9 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True): Args: tensor: A `Tensor`. - dtype: A type for the returned `Tensor`. Must be `float32`, `float64`, - `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`, - `complex64`, `complex128` or `bool`. + dtype: A type for the returned `Tensor`. Must be `float16`, `float32`, + `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`, + `complex64`, `complex128`, `bool` or `string`. name: A name for the operation (optional). optimize: if true, attempt to statically determine the shape of 'tensor' and encode it as a constant. diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py index 7dbccf1caf..ac03d30fcd 100644 --- a/tensorflow/python/ops/functional_ops.py +++ b/tensorflow/python/ops/functional_ops.py @@ -458,7 +458,7 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, For example, if `elems` is `(t1, [t2, t3])` and `initializer` is `[i1, i2]` then an appropriate signature for `fn` in `python2` is: - `fn = lambda (acc_p1, acc_p2), (t1 [t2, t3]):` and `fn` must return a list, + `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list, `[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the one that works in `python3`, is: `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples. 
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 28b26a09a5..9f06c0ee1f 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import image_grad # pylint: disable=unused-import from tensorflow.python.ops import linalg_grad # pylint: disable=unused-import from tensorflow.python.ops import linalg_ops # pylint: disable=unused-import from tensorflow.python.ops import logging_ops # pylint: disable=unused-import +from tensorflow.python.ops import manip_grad # pylint: disable=unused-import from tensorflow.python.ops import math_grad # pylint: disable=unused-import from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index 3b0b5a978c..de12c5f63f 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -49,6 +49,10 @@ See the @{$python/image} guide. 
@@grayscale_to_rgb @@hsv_to_rgb @@rgb_to_hsv +@@rgb_to_yiq +@@yiq_to_rgb +@@rgb_to_yuv +@@yuv_to_rgb @@convert_image_dtype @@adjust_brightness @@random_brightness diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 2c231ef56c..14a38f25d1 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -1508,7 +1508,7 @@ def sample_distorted_bounding_box(image_size, bounding_boxes, seed=None, seed2=None, - min_object_covered=None, + min_object_covered=0.1, aspect_ratio_range=None, area_range=None, max_attempts=None, @@ -1669,3 +1669,107 @@ def non_max_suppression(boxes, return gen_image_ops._non_max_suppression_v2(boxes, scores, max_output_size, iou_threshold) # pylint: enable=protected-access + + +_rgb_to_yiq_kernel = [[0.299, 0.59590059, + 0.2115], [0.587, -0.27455667, -0.52273617], + [0.114, -0.32134392, 0.31119955]] + + +def rgb_to_yiq(images): + """Converts one or more images from RGB to YIQ. + + Outputs a tensor of the same shape as the `images` tensor, containing the YIQ + value of the pixels. + The output is only well defined if the value in images are in [0,1]. + + Args: + images: 2-D or higher rank. Image data to convert. Last dimension must be + size 3. + + Returns: + images: tensor with the same shape as `images`. + """ + images = ops.convert_to_tensor(images, name='images') + kernel = ops.convert_to_tensor( + _rgb_to_yiq_kernel, dtype=images.dtype, name='kernel') + ndims = images.get_shape().ndims + return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) + + +_yiq_to_rgb_kernel = [[1, 1, 1], [0.95598634, -0.27201283, -1.10674021], + [0.6208248, -0.64720424, 1.70423049]] + + +def yiq_to_rgb(images): + """Converts one or more images from YIQ to RGB. + + Outputs a tensor of the same shape as the `images` tensor, containing the RGB + value of the pixels. 
+ The output is only well defined if the Y value in images are in [0,1], + I value are in [-0.5957,0.5957] and Q value are in [-0.5226,0.5226]. + + Args: + images: 2-D or higher rank. Image data to convert. Last dimension must be + size 3. + + Returns: + images: tensor with the same shape as `images`. + """ + images = ops.convert_to_tensor(images, name='images') + kernel = ops.convert_to_tensor( + _yiq_to_rgb_kernel, dtype=images.dtype, name='kernel') + ndims = images.get_shape().ndims + return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) + + +_rgb_to_yuv_kernel = [[0.299, -0.14714119, + 0.61497538], [0.587, -0.28886916, -0.51496512], + [0.114, 0.43601035, -0.10001026]] + + +def rgb_to_yuv(images): + """Converts one or more images from RGB to YUV. + + Outputs a tensor of the same shape as the `images` tensor, containing the YUV + value of the pixels. + The output is only well defined if the value in images are in [0,1]. + + Args: + images: 2-D or higher rank. Image data to convert. Last dimension must be + size 3. + + Returns: + images: tensor with the same shape as `images`. + """ + images = ops.convert_to_tensor(images, name='images') + kernel = ops.convert_to_tensor( + _rgb_to_yuv_kernel, dtype=images.dtype, name='kernel') + ndims = images.get_shape().ndims + return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) + + +_yuv_to_rgb_kernel = [[1, 1, 1], [0, -0.394642334, 2.03206185], + [1.13988303, -0.58062185, 0]] + + +def yuv_to_rgb(images): + """Converts one or more images from YUV to RGB. + + Outputs a tensor of the same shape as the `images` tensor, containing the RGB + value of the pixels. + The output is only well defined if the Y value in images are in [0,1], + U and V value are in [-0.5,0.5]. + + Args: + images: 2-D or higher rank. Image data to convert. Last dimension must be + size 3. + + Returns: + images: tensor with the same shape as `images`. 
+ """ + images = ops.convert_to_tensor(images, name='images') + kernel = ops.convert_to_tensor( + _yuv_to_rgb_kernel, dtype=images.dtype, name='kernel') + ndims = images.get_shape().ndims + return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]]) diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 47dd8231c0..b12bd3d5b0 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -85,6 +85,64 @@ class RGBToHSVTest(test_util.TensorFlowTestCase): self.assertAllClose(rgb_tf, rgb_np) +class RGBToYIQTest(test_util.TensorFlowTestCase): + + def testBatch(self): + # Build an arbitrary RGB image + np.random.seed(7) + batch_size = 5 + shape = (batch_size, 2, 7, 3) + + for nptype in [np.float32, np.float64]: + inp = np.random.rand(*shape).astype(nptype) + + # Convert to YIQ and back, as a batch and individually + with self.test_session(use_gpu=True) as sess: + batch0 = constant_op.constant(inp) + batch1 = image_ops.rgb_to_yiq(batch0) + batch2 = image_ops.yiq_to_rgb(batch1) + split0 = array_ops.unstack(batch0) + split1 = list(map(image_ops.rgb_to_yiq, split0)) + split2 = list(map(image_ops.yiq_to_rgb, split1)) + join1 = array_ops.stack(split1) + join2 = array_ops.stack(split2) + batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2]) + + # Verify that processing batch elements together is the same as separate + self.assertAllClose(batch1, join1, rtol=1e-4, atol=1e-4) + self.assertAllClose(batch2, join2, rtol=1e-4, atol=1e-4) + self.assertAllClose(batch2, inp, rtol=1e-4, atol=1e-4) + + +class RGBToYUVTest(test_util.TensorFlowTestCase): + + def testBatch(self): + # Build an arbitrary RGB image + np.random.seed(7) + batch_size = 5 + shape = (batch_size, 2, 7, 3) + + for nptype in [np.float32, np.float64]: + inp = np.random.rand(*shape).astype(nptype) + + # Convert to YUV and back, as a batch and individually + with self.test_session(use_gpu=True) as sess: + batch0 = 
constant_op.constant(inp) + batch1 = image_ops.rgb_to_yuv(batch0) + batch2 = image_ops.yuv_to_rgb(batch1) + split0 = array_ops.unstack(batch0) + split1 = list(map(image_ops.rgb_to_yuv, split0)) + split2 = list(map(image_ops.yuv_to_rgb, split1)) + join1 = array_ops.stack(split1) + join2 = array_ops.stack(split2) + batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2]) + + # Verify that processing batch elements together is the same as separate + self.assertAllClose(batch1, join1, rtol=1e-4, atol=1e-4) + self.assertAllClose(batch2, join2, rtol=1e-4, atol=1e-4) + self.assertAllClose(batch2, inp, rtol=1e-4, atol=1e-4) + + class GrayscaleToRGBTest(test_util.TensorFlowTestCase): def _RGBToGrayscale(self, images): @@ -1839,6 +1897,26 @@ class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase): self.assertAllEqual([3], end.get_shape().as_list()) self.assertAllEqual([1, 1, 4], bbox_for_drawing.get_shape().as_list()) + def testDefaultMinObjectCovered(self): + # By default min_object_covered=0.1 if not provided + with self.test_session(use_gpu=True): + image_size = constant_op.constant( + [40, 50, 1], shape=[3], dtype=dtypes.int32) + bounding_box = constant_op.constant( + [0.0, 0.0, 1.0, 1.0], + shape=[4], + dtype=dtypes.float32, + ) + begin, end, bbox_for_drawing = image_ops.sample_distorted_bounding_box( + image_size=image_size, + bounding_boxes=bounding_box, + aspect_ratio_range=(0.75, 1.33), + area_range=(0.05, 1.0)) + + self.assertAllEqual([3], begin.get_shape().as_list()) + self.assertAllEqual([3], end.get_shape().as_list()) + self.assertAllEqual([1, 1, 4], bbox_for_drawing.get_shape().as_list()) + class ResizeImagesTest(test_util.TensorFlowTestCase): @@ -3092,6 +3170,40 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase): boxes, scores, max_output_size, iou_threshold).eval() self.assertAllClose(selected_indices, [3, 0, 5]) + def testInvalidShape(self): + # The boxes should be 2D of shape [num_boxes, 4]. 
+ with self.assertRaisesRegexp(ValueError, + "Shape must be rank 2 but is rank 1"): + boxes = constant_op.constant([0.0, 0.0, 1.0, 1.0]) + scores = constant_op.constant([0.9]) + image_ops.non_max_suppression(boxes, scores, 3, 0.5) + + with self.assertRaisesRegexp(ValueError, "Dimension must be 4 but is 3"): + boxes = constant_op.constant([[0.0, 0.0, 1.0]]) + scores = constant_op.constant([0.9]) + image_ops.non_max_suppression(boxes, scores, 3, 0.5) + + # The scores should be 1D of shape [num_boxes]. + with self.assertRaisesRegexp(ValueError, + "Shape must be rank 1 but is rank 2"): + boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]]) + scores = constant_op.constant([[0.9]]) + image_ops.non_max_suppression(boxes, scores, 3, 0.5) + + # The max_output_size should be a scaler (0-D). + with self.assertRaisesRegexp(ValueError, + "Shape must be rank 0 but is rank 1"): + boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]]) + scores = constant_op.constant([0.9]) + image_ops.non_max_suppression(boxes, scores, [3], 0.5) + + # The iou_threshold should be a scaler (0-D). + with self.assertRaisesRegexp(ValueError, + "Shape must be rank 0 but is rank 2"): + boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]]) + scores = constant_op.constant([0.9]) + image_ops.non_max_suppression(boxes, scores, 3, [[0.5]]) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py index 13a32c83d9..3cbbf3412a 100644 --- a/tensorflow/python/ops/linalg_grad.py +++ b/tensorflow/python/ops/linalg_grad.py @@ -277,20 +277,28 @@ def _SvdGrad(op, grad_s, grad_u, grad_v): # https://j-towns.github.io/papers/svd-derivative.pdf a = op.inputs[0] a_shape = a.get_shape().with_rank_at_least(2) + grad_s_mat = array_ops.matrix_diag(grad_s) - if op.get_attr("compute_uv"): - # TODO(rmlarsen): Make this work with complex types. 
- if a.dtype.is_complex: - raise NotImplementedError( - "SVD gradient is not implemented for complex types and " - "compute_uv=True.") - grad_u_shape = grad_u.get_shape().with_rank_at_least(2) - grad_v_shape = grad_v.get_shape().with_rank_at_least(2) - m = a_shape[-2].merge_with(grad_u_shape[-2]) - n = a_shape[-1].merge_with(grad_v_shape[-2]) - batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with( - grad_v_shape[:-2]) - a_shape = batch_shape.concatenate([m, n]) + if not op.get_attr("compute_uv"): + s, u, v = linalg_ops.svd(a, compute_uv=True) + grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True)) + grad_a.set_shape(a_shape) + return grad_a + + full_matrices = op.get_attr("full_matrices") + + # TODO(rmlarsen): Make this work with complex types. + if a.dtype.is_complex: + raise NotImplementedError( + "SVD gradient is not implemented for complex types and " + "compute_uv=True.") + grad_u_shape = grad_u.get_shape().with_rank_at_least(2) + grad_v_shape = grad_v.get_shape().with_rank_at_least(2) + m = a_shape[-2].merge_with(grad_u_shape[-2]) + n = a_shape[-1].merge_with(grad_v_shape[-2]) + batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with( + grad_v_shape[:-2]) + a_shape = batch_shape.concatenate([m, n]) m = a_shape[-2].value n = a_shape[-1].value @@ -300,12 +308,9 @@ def _SvdGrad(op, grad_s, grad_u, grad_v): "SVD gradient has not been implemented for input with unknown " "inner matrix shape.") - if not op.get_attr("compute_uv"): - s, u, v = linalg_ops.svd(a, compute_uv=True, full_matrices=True) - else: - s = op.outputs[0] - u = op.outputs[1] - v = op.outputs[2] + s = op.outputs[0] + u = op.outputs[1] + v = op.outputs[2] use_adjoint = False if m > n: @@ -317,19 +322,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v): grad_u, grad_v = grad_v, grad_u with ops.control_dependencies([grad_s, grad_u, grad_v]): - grad_s_mat = array_ops.matrix_diag(grad_s) - if not op.get_attr("compute_uv"): - if use_adjoint: - grad_a = 
math_ops.matmul( - v[..., :, :m], math_ops.matmul(u, grad_s_mat), adjoint_b=True) - else: - grad_a = math_ops.matmul(u, - math_ops.matmul( - grad_s_mat, v[..., :, :m], adjoint_b=True)) - grad_a.set_shape(a_shape) - return grad_a - - if op.get_attr("full_matrices") and abs(m - n) > 1: + if full_matrices and abs(m - n) > 1: raise NotImplementedError( "svd gradient is not implemented for abs(m - n) > 1 " "when full_matrices is True") @@ -371,7 +364,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v): gv1t_v1 = math_ops.matmul(gv1t, v1) term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True) - if op.get_attr("full_matrices"): + if full_matrices: v2 = v[..., :, m:n] grad_v2 = grad_v[..., :, m:n] diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index e75a9b22e4..84afbf0627 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -547,12 +547,13 @@ def mean_pairwise_squared_error( num_present_per_batch = _num_present(diffs, weights, per_batch=True) term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, - num_present_per_batch) + num_present_per_batch - 1) sum_diff = math_ops.reduce_sum( diffs, reduction_indices=reduction_indices, keep_dims=True) - term2 = 2.0 * _safe_div(math_ops.square(sum_diff), - math_ops.square(num_present_per_batch)) + term2 = 2.0 * _safe_div( + math_ops.square(sum_diff), + math_ops.multiply(num_present_per_batch, num_present_per_batch - 1)) weighted_losses = math_ops.multiply(term1 - term2, weights) loss = math_ops.reduce_sum(weighted_losses) diff --git a/tensorflow/python/ops/manip_grad.py b/tensorflow/python/ops/manip_grad.py new file mode 100644 index 0000000000..bb2069359d --- /dev/null +++ b/tensorflow/python/ops/manip_grad.py @@ -0,0 +1,31 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gradients for operators defined in manip_ops.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import manip_ops + + +@ops.RegisterGradient("Roll") +def _RollGrad(op, grad): + # The gradient is just the roll reversed + shift = op.inputs[1] + axis = op.inputs[2] + roll_grad = manip_ops.roll(grad, -shift, axis) + return roll_grad, None, None diff --git a/tensorflow/python/ops/manip_ops.py b/tensorflow/python/ops/manip_ops.py new file mode 100644 index 0000000000..91e15b47b9 --- /dev/null +++ b/tensorflow/python/ops/manip_ops.py @@ -0,0 +1,38 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Operators for manipulating tensors. + +@@roll +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops +from tensorflow.python.util.all_util import remove_undocumented + + +# pylint: disable=protected-access +def roll(input, shift, axis): # pylint: disable=redefined-builtin + return _gen_manip_ops.roll(input, shift, axis) + + +roll.__doc__ = _gen_manip_ops.roll.__doc__ +# pylint: enable=protected-access + +_allowed_symbols = ['roll'] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 827e3caa36..9a8ac93de9 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -2826,10 +2826,14 @@ def tensordot(a, b, axes, name=None): """Generates two sets of contraction axes for the two tensor arguments.""" a_shape = a.get_shape() if isinstance(axes, compat.integral_types): - if axes < 1: - raise ValueError("'axes' must be at least 1.") + if axes < 0: + raise ValueError("'axes' must be at least 0.") if a_shape.ndims is not None: - return range(a_shape.ndims - axes, a_shape.ndims), range(axes) + if axes > a_shape.ndims: + raise ValueError("'axes' must not be larger than the number of " + "dimensions of tensor %s." 
% a) + return (list(xrange(a_shape.ndims - axes, a_shape.ndims)), + list(xrange(axes))) else: rank = array_ops.rank(a) return (range(rank - axes, rank, dtype=dtypes.int32), diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py index 24c6f64f0a..da80e72071 100644 --- a/tensorflow/python/ops/rnn.py +++ b/tensorflow/python/ops/rnn.py @@ -1127,6 +1127,12 @@ def raw_rnn(cell, loop_fn, def _copy_some_through(current, candidate): """Copy some tensors through via array_ops.where.""" def copy_fn(cur_i, cand_i): + # TensorArray and scalar get passed through. + if isinstance(cur_i, tensor_array_ops.TensorArray): + return cand_i + if cur_i.shape.ndims == 0: + return cand_i + # Otherwise propagate the old or the new value. with ops.colocate_with(cand_i): return array_ops.where(elements_finished, cur_i, cand_i) return nest.map_structure(copy_fn, current, candidate) diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index 30bf4e4ef1..009d1dc3b9 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -25,6 +25,7 @@ import sys as _sys # Imports the following modules so that @RegisterGradient get executed. from tensorflow.python.ops import array_grad from tensorflow.python.ops import data_flow_grad +from tensorflow.python.ops import manip_grad from tensorflow.python.ops import math_grad from tensorflow.python.ops import sparse_grad from tensorflow.python.ops import spectral_grad @@ -42,11 +43,13 @@ from tensorflow.python.ops.special_math_ops import * # TODO(vrv): Switch to import * once we're okay with exposing the module. 
from tensorflow.python.ops.confusion_matrix import confusion_matrix from tensorflow.python.ops.control_flow_ops import Assert +from tensorflow.python.ops.control_flow_ops import case +from tensorflow.python.ops.control_flow_ops import cond from tensorflow.python.ops.control_flow_ops import group from tensorflow.python.ops.control_flow_ops import no_op +# pylint: disable=redefined-builtin from tensorflow.python.ops.control_flow_ops import tuple -from tensorflow.python.ops.control_flow_ops import cond -from tensorflow.python.ops.control_flow_ops import case +# pylint: enable=redefined-builtin from tensorflow.python.ops.control_flow_ops import while_loop from tensorflow.python.ops.data_flow_ops import * from tensorflow.python.ops.functional_ops import * @@ -59,6 +62,7 @@ from tensorflow.python.ops.logging_ops import Print from tensorflow.python.ops.logging_ops import get_summary_op from tensorflow.python.ops.lookup_ops import initialize_all_tables from tensorflow.python.ops.lookup_ops import tables_initializer +from tensorflow.python.ops.manip_ops import * from tensorflow.python.ops.math_ops import * from tensorflow.python.ops.numerics import * from tensorflow.python.ops.parsing_ops import * @@ -105,6 +109,7 @@ from tensorflow.python.ops import init_ops as _init_ops from tensorflow.python.ops import io_ops as _io_ops from tensorflow.python.ops import linalg_ops as _linalg_ops from tensorflow.python.ops import logging_ops as _logging_ops +from tensorflow.python.ops import manip_ops as _manip_ops from tensorflow.python.ops import math_ops as _math_ops from tensorflow.python.ops import numerics as _numerics from tensorflow.python.ops import parsing_ops as _parsing_ops @@ -264,34 +269,36 @@ _allowed_symbols = (_allowed_symbols_array_ops + _allowed_symbols_misc + _allowed_symbols_partitioned_variables) -remove_undocumented(__name__, _allowed_symbols, - [_sys.modules[__name__], - _array_ops, - _check_ops, - _clip_ops, - _confusion_matrix, - _control_flow_ops, - 
_constant_op, - _data_flow_ops, - _functional_ops, - _gradients, - _histogram_ops, - _init_ops, - _io_ops, - _linalg_ops, - _logging_ops, - _math_ops, - _numerics, - _parsing_ops, - _partitioned_variables, - _random_ops, - _script_ops, - _session_ops, - _sparse_ops, - _special_math_ops, - _state_ops, - _string_ops, - _template, - _tensor_array_ops, - _variable_scope, - _variables,]) +remove_undocumented(__name__, _allowed_symbols, [ + _sys.modules[__name__], + _array_ops, + _check_ops, + _clip_ops, + _confusion_matrix, + _control_flow_ops, + _constant_op, + _data_flow_ops, + _functional_ops, + _gradients, + _histogram_ops, + _init_ops, + _io_ops, + _linalg_ops, + _logging_ops, + _manip_ops, + _math_ops, + _numerics, + _parsing_ops, + _partitioned_variables, + _random_ops, + _script_ops, + _session_ops, + _sparse_ops, + _special_math_ops, + _state_ops, + _string_ops, + _template, + _tensor_array_ops, + _variable_scope, + _variables, +]) diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index ddfd6be6da..bebf1d5e0d 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -235,13 +235,10 @@ def load(sess, tags, export_dir, **saver_kwargs): asset_tensors_dictionary = _get_asset_tensors(export_dir, meta_graph_def_to_load) - main_op_tensor = _get_main_op_tensor(meta_graph_def_to_load) + main_op_tensor = ( + _get_main_op_tensor(meta_graph_def_to_load) or + (_get_legacy_init_op_tensor(meta_graph_def_to_load))) if main_op_tensor is not None: sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary) - else: - legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load) - if legacy_init_op_tensor is not None: - sess.run( - fetches=[legacy_init_op_tensor], feed_dict=asset_tensors_dictionary) return meta_graph_def_to_load diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index 0ddf09260b..affa97062a 
100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -72,7 +72,8 @@ def freeze_graph_with_def_protos(input_graph_def, variable_names_blacklist="", input_meta_graph_def=None, input_saved_model_dir=None, - saved_model_tags=None): + saved_model_tags=None, + checkpoint_version=saver_pb2.SaverDef.V2): """Converts all variables in a graph and checkpoint into constants.""" del restore_op_name, filename_tensor_name # Unused by updated loading code. @@ -100,7 +101,8 @@ def freeze_graph_with_def_protos(input_graph_def, _ = importer.import_graph_def(input_graph_def, name="") with session.Session() as sess: if input_saver_def: - saver = saver_lib.Saver(saver_def=input_saver_def) + saver = saver_lib.Saver( + saver_def=input_saver_def, write_version=checkpoint_version) saver.restore(sess, input_checkpoint) elif input_meta_graph_def: restorer = saver_lib.import_meta_graph( @@ -124,7 +126,8 @@ def freeze_graph_with_def_protos(input_graph_def, # 'global_step' or a similar housekeeping element) so skip it. 
continue var_list[key] = tensor - saver = saver_lib.Saver(var_list=var_list) + saver = saver_lib.Saver( + var_list=var_list, write_version=checkpoint_version) saver.restore(sess, input_checkpoint) if initializer_nodes: sess.run(initializer_nodes.split(",")) @@ -217,7 +220,8 @@ def freeze_graph(input_graph, variable_names_blacklist="", input_meta_graph=None, input_saved_model_dir=None, - saved_model_tags=tag_constants.SERVING): + saved_model_tags=tag_constants.SERVING, + checkpoint_version=saver_pb2.SaverDef.V2): """Converts all variables in a graph and checkpoint into constants.""" input_graph_def = None if input_saved_model_dir: @@ -233,10 +237,21 @@ def freeze_graph(input_graph, if input_saver: input_saver_def = _parse_input_saver_proto(input_saver, input_binary) freeze_graph_with_def_protos( - input_graph_def, input_saver_def, input_checkpoint, output_node_names, - restore_op_name, filename_tensor_name, output_graph, clear_devices, - initializer_nodes, variable_names_whitelist, variable_names_blacklist, - input_meta_graph_def, input_saved_model_dir, saved_model_tags.split(",")) + input_graph_def, + input_saver_def, + input_checkpoint, + output_node_names, + restore_op_name, + filename_tensor_name, + output_graph, + clear_devices, + initializer_nodes, + variable_names_whitelist, + variable_names_blacklist, + input_meta_graph_def, + input_saved_model_dir, + saved_model_tags.split(","), + checkpoint_version=checkpoint_version) def main(unused_args): @@ -246,7 +261,7 @@ def main(unused_args): FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes, FLAGS.variable_names_whitelist, FLAGS.variable_names_blacklist, FLAGS.input_meta_graph, FLAGS.input_saved_model_dir, - FLAGS.saved_model_tags) + FLAGS.saved_model_tags, FLAGS.checkpoint_version) if __name__ == "__main__": @@ -268,6 +283,11 @@ if __name__ == "__main__": default="", help="TensorFlow variables file to load.") parser.add_argument( + "--checkpoint_version", + type=int, + default=saver_pb2.SaverDef.V2, 
+ help="Tensorflow variable file format") + parser.add_argument( "--output_graph", type=str, default="", diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py index feeed7102c..91f0061ebc 100644 --- a/tensorflow/python/tools/freeze_graph_test.py +++ b/tensorflow/python/tools/freeze_graph_test.py @@ -84,9 +84,19 @@ class FreezeGraphTest(test_util.TensorFlowTestCase): input_meta_graph = checkpoint_meta_graph_file freeze_graph.freeze_graph( - input_graph_path, input_saver_def_path, input_binary, checkpoint_path, - output_node_names, restore_op_name, filename_tensor_name, - output_graph_path, clear_devices, "", "", input_meta_graph) + input_graph_path, + input_saver_def_path, + input_binary, + checkpoint_path, + output_node_names, + restore_op_name, + filename_tensor_name, + output_graph_path, + clear_devices, + "", + "", + input_meta_graph, + checkpoint_version=saver_write_version) # Now we make sure the variable is now a constant, and that the graph still # produces the expected result. 
diff --git a/tensorflow/python/tools/optimize_for_inference_lib.py b/tensorflow/python/tools/optimize_for_inference_lib.py index c2687bf557..9c19271222 100644 --- a/tensorflow/python/tools/optimize_for_inference_lib.py +++ b/tensorflow/python/tools/optimize_for_inference_lib.py @@ -349,6 +349,7 @@ def fold_batch_norms(input_graph_def): bias_add_op.op = "BiasAdd" bias_add_op.name = node.name bias_add_op.attr["T"].CopyFrom(conv_op.attr["T"]) + bias_add_op.attr["data_format"].CopyFrom(conv_op.attr["data_format"]) bias_add_op.input.extend([new_conv_op.name, offset_op.name]) new_ops.extend([scaled_weights_op, new_conv_op, offset_op, bias_add_op]) diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py index 7686bb0f14..084a4500f8 100644 --- a/tensorflow/python/tools/optimize_for_inference_test.py +++ b/tensorflow/python/tools/optimize_for_inference_test.py @@ -173,48 +173,56 @@ class OptimizeForInferenceTest(test.TestCase): self.assertNotEqual("BatchNormWithGlobalNormalization", node.op) def testFoldFusedBatchNorms(self): - with self.test_session() as sess: - inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6] - input_op = constant_op.constant( - np.array(inputs), shape=[1, 1, 6, 2], dtype=dtypes.float32) - weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4] - weights_op = constant_op.constant( - np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32) - conv_op = nn_ops.conv2d( - input_op, weights_op, [1, 1, 1, 1], padding="SAME", name="conv_op") - mean_op = constant_op.constant( - np.array([10, 20]), shape=[2], dtype=dtypes.float32) - variance_op = constant_op.constant( - np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32) - beta_op = constant_op.constant( - np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32) - gamma_op = constant_op.constant( - np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32) - ops.get_default_graph().graph_def_versions.producer = 9 - gen_nn_ops._fused_batch_norm( - conv_op, 
- gamma_op, - beta_op, - mean_op, - variance_op, - 0.00001, - is_training=False, - name="output") - original_graph_def = sess.graph_def - original_result = sess.run(["output:0"]) - optimized_graph_def = optimize_for_inference_lib.fold_batch_norms( - original_graph_def) - - with self.test_session() as sess: - _ = importer.import_graph_def( - optimized_graph_def, input_map={}, name="optimized") - optimized_result = sess.run(["optimized/output:0"]) - - self.assertAllClose( - original_result, optimized_result, rtol=1e-04, atol=1e-06) - - for node in optimized_graph_def.node: - self.assertNotEqual("FusedBatchNorm", node.op) + for data_format, use_gpu in [("NHWC", False), ("NCHW", True)]: + with self.test_session(use_gpu=use_gpu) as sess: + inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6] + input_op = constant_op.constant( + np.array(inputs), + shape=[1, 1, 6, 2] if data_format == "NHWC" else [1, 2, 1, 6], + dtype=dtypes.float32) + weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4] + weights_op = constant_op.constant( + np.array(weights), shape=[1, 2, 2, 2], dtype=dtypes.float32) + conv_op = nn_ops.conv2d( + input_op, + weights_op, [1, 1, 1, 1], + padding="SAME", + data_format=data_format, + name="conv_op") + mean_op = constant_op.constant( + np.array([10, 20]), shape=[2], dtype=dtypes.float32) + variance_op = constant_op.constant( + np.array([0.25, 0.5]), shape=[2], dtype=dtypes.float32) + beta_op = constant_op.constant( + np.array([0.1, 0.6]), shape=[2], dtype=dtypes.float32) + gamma_op = constant_op.constant( + np.array([1.0, 2.0]), shape=[2], dtype=dtypes.float32) + ops.get_default_graph().graph_def_versions.producer = 9 + gen_nn_ops._fused_batch_norm( + conv_op, + gamma_op, + beta_op, + mean_op, + variance_op, + 0.00001, + is_training=False, + data_format=data_format, + name="output") + original_graph_def = sess.graph_def + original_result = sess.run(["output:0"]) + optimized_graph_def = optimize_for_inference_lib.fold_batch_norms( + original_graph_def) + + with 
self.test_session(use_gpu=use_gpu) as sess: + _ = importer.import_graph_def( + optimized_graph_def, input_map={}, name="optimized") + optimized_result = sess.run(["optimized/output:0"]) + + self.assertAllClose( + original_result, optimized_result, rtol=1e-04, atol=1e-06) + + for node in optimized_graph_def.node: + self.assertNotEqual("FusedBatchNorm", node.op) def testFuseResizePadAndConv(self): with self.test_session() as sess: diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 667a4b1db8..33f6debbcb 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -31,6 +31,7 @@ import warnings import numpy as np +from six import integer_types from tensorflow.contrib.saved_model.python.saved_model import reader from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils from tensorflow.core.example import example_pb2 @@ -440,7 +441,7 @@ def _create_example_string(example_dict): elif isinstance(feature_list[0], str): example.features.feature[feature_name].bytes_list.value.extend( feature_list) - elif isinstance(feature_list[0], (int, long)): + elif isinstance(feature_list[0], integer_types): example.features.feature[feature_name].int64_list.value.extend( feature_list) else: diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index 17e07e171a..aae757b99a 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -336,7 +336,7 @@ class CheckpointSaverListener(object): `CheckpointSaverHook`, as in this example: ```python - class ExampleCheckpointSaverListerner(CheckpointSaverListener): + class ExampleCheckpointSaverListener(CheckpointSaverListener): def begin(self): # You can add ops to the graph here. 
print('Starting the session.') @@ -352,7 +352,7 @@ class CheckpointSaverListener(object): print('Done with the session.') ... - listener = ExampleCheckpointSaverListerner() + listener = ExampleCheckpointSaverListener() saver_hook = tf.train.CheckpointSaverHook( checkpoint_dir, listeners=[listener]) with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]): diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index 992184ec9e..bd9985a7c5 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -58,6 +58,8 @@ _restore_sparse = sparse_ops._take_many_sparse_from_tensors_map def match_filenames_once(pattern, name=None): """Save the list of files matching pattern, so it is only computed once. + NOTE: The order of the files returned can be non-deterministic. + Args: pattern: A file pattern (glob), or 1D tensor of file patterns. name: A name for the operations (optional). diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 3888e9bba4..0c1c8e664b 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -1597,9 +1597,9 @@ class Saver(object): [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Returns: - A string: path prefix used for the checkpoint files. If the saver is - sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn' - is the number of shards created. + A string: path prefix used for the checkpoint files. If checkpoint + format is V1 and the saver is sharded, this string ends with: + '-?????-of-nnnnn' where 'nnnnn' is the number of shards created. If the saver is empty, returns None. 
Raises:
@@ -1749,6 +1749,12 @@ class Saver(object):
       return
     if save_path is None:
       raise ValueError("Can't load save_path when it is None.")
+    if (os.path.isfile(save_path) and
+        self._write_version not in (
+            saver_pb2.SaverDef.V1, saver_pb2.SaverDef.LEGACY)):
+      raise ValueError("The specified path: %s is a file."
+                       " Please specify only the path prefix"
+                       " to the checkpoint files." % save_path)
     logging.info("Restoring parameters from %s", save_path)
     if context.in_graph_mode():
       sess.run(self.saver_def.restore_op_name,
diff --git a/tensorflow/python/util/compat_internal.py b/tensorflow/python/util/compat_internal.py
new file mode 100644
index 0000000000..fee1d6fab7
--- /dev/null
+++ b/tensorflow/python/util/compat_internal.py
@@ -0,0 +1,38 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions for Python 2 vs. 3 compatibility that are private to TensorFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.util import compat
+
+
+def path_to_str(path):
+  """Converts a `PathLike` object to its file system path representation.
+
+  Objects that do not implement `__fspath__` are returned unchanged.
+
+  Args:
+    path: An object that can be converted to path representation.
+
+  Returns:
+    A `str` if `path` implements `__fspath__`; otherwise `path` itself,
+    unchanged.
+  """
+  if hasattr(path, "__fspath__"):
+    path = compat.as_str_any(path.__fspath__())
+  return path
|