path: root/tensorflow/python
author     avijit-nervana <avijit.chakraborty@intel.com>  2018-08-28 16:12:57 -0700
committer  avijit-nervana <avijit.chakraborty@intel.com>  2018-08-28 16:12:57 -0700
commit     757538bd14f24de3d7bf654a03c6543bb06a8e75 (patch)
tree       4873885feca3a3e5787241477ef8d1333c494d1e /tensorflow/python
parent     6b25c37daaa6a063b6b687252343db5453a84b8b (diff)
parent     7f52de1a2b03568dc98ad51685b56661a5105da6 (diff)
Merge branch 'master' into avijit/add-cpu-backend
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/BUILD | 33
-rw-r--r--  tensorflow/python/compat/compat.py | 2
-rw-r--r--  tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py | 20
-rw-r--r--  tensorflow/python/data/ops/dataset_ops.py | 17
-rw-r--r--  tensorflow/python/data/util/nest.py | 37
-rw-r--r--  tensorflow/python/data/util/nest_test.py | 27
-rw-r--r--  tensorflow/python/debug/BUILD | 2
-rw-r--r--  tensorflow/python/distribute/BUILD | 39
-rw-r--r--  tensorflow/python/distribute/distribute_config.py | 45
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator.py | 120
-rw-r--r--  tensorflow/python/distribute/distribute_coordinator_test.py | 64
-rw-r--r--  tensorflow/python/distribute/estimator_training.py | 264
-rw-r--r--  tensorflow/python/eager/BUILD | 1
-rw-r--r--  tensorflow/python/eager/context.py | 17
-rw-r--r--  tensorflow/python/eager/core_test.py | 13
-rw-r--r--  tensorflow/python/eager/execution_callbacks.py | 8
-rw-r--r--  tensorflow/python/eager/function.py | 359
-rw-r--r--  tensorflow/python/eager/function_test.py | 243
-rwxr-xr-x [-rw-r--r--]  tensorflow/python/eager/pywrap_tfe.h | 2
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_src.cc | 4
-rw-r--r--  tensorflow/python/estimator/BUILD | 3
-rw-r--r--  tensorflow/python/estimator/canned/head.py | 4
-rw-r--r--  tensorflow/python/estimator/estimator.py | 81
-rw-r--r--  tensorflow/python/estimator/export/export.py | 45
-rw-r--r--  tensorflow/python/estimator/export/export_test.py | 30
-rw-r--r--  tensorflow/python/estimator/keras.py | 11
-rw-r--r--  tensorflow/python/estimator/run_config.py | 29
-rw-r--r--  tensorflow/python/estimator/training.py | 22
-rw-r--r--  tensorflow/python/estimator/util.py | 8
-rw-r--r--  tensorflow/python/feature_column/feature_column_test.py | 74
-rw-r--r--  tensorflow/python/feature_column/feature_column_v2.py | 550
-rw-r--r--  tensorflow/python/feature_column/feature_column_v2_test.py | 764
-rw-r--r--  tensorflow/python/framework/device.py | 38
-rw-r--r--  tensorflow/python/framework/function.py | 23
-rw-r--r--  tensorflow/python/framework/function_def_to_graph.py | 20
-rw-r--r--  tensorflow/python/framework/function_def_to_graph_test.py | 45
-rw-r--r--  tensorflow/python/framework/ops.py | 64
-rw-r--r--  tensorflow/python/framework/ops_enable_eager_test.py | 38
-rw-r--r--  tensorflow/python/framework/python_op_gen_main.cc | 2
-rw-r--r--  tensorflow/python/framework/smart_cond.py | 6
-rw-r--r--  tensorflow/python/framework/subscribe.py | 7
-rw-r--r--  tensorflow/python/framework/test_util.py | 2
-rw-r--r--  tensorflow/python/framework/test_util_test.py | 28
-rw-r--r--  tensorflow/python/grappler/graph_analyzer.i | 26
-rw-r--r--  tensorflow/python/grappler/graph_analyzer.py | 46
-rwxr-xr-x  tensorflow/python/keras/BUILD | 2
-rw-r--r--  tensorflow/python/keras/activations_test.py | 20
-rw-r--r--  tensorflow/python/keras/applications/__init__.py | 51
-rw-r--r--  tensorflow/python/keras/applications/applications_test.py | 8
-rw-r--r--  tensorflow/python/keras/applications/densenet.py | 47
-rw-r--r--  tensorflow/python/keras/applications/imagenet_utils.py | 33
-rw-r--r--  tensorflow/python/keras/applications/inception_resnet_v2.py | 26
-rw-r--r--  tensorflow/python/keras/applications/inception_v3.py | 25
-rw-r--r--  tensorflow/python/keras/applications/mobilenet.py | 25
-rw-r--r--  tensorflow/python/keras/applications/mobilenet_v2.py | 24
-rw-r--r--  tensorflow/python/keras/applications/nasnet.py | 35
-rw-r--r--  tensorflow/python/keras/applications/resnet50.py | 24
-rw-r--r--  tensorflow/python/keras/applications/vgg16.py | 24
-rw-r--r--  tensorflow/python/keras/applications/vgg19.py | 24
-rw-r--r--  tensorflow/python/keras/applications/xception.py | 25
-rw-r--r--  tensorflow/python/keras/backend.py | 5
-rw-r--r--  tensorflow/python/keras/backend_test.py | 87
-rw-r--r--  tensorflow/python/keras/constraints_test.py | 8
-rw-r--r--  tensorflow/python/keras/engine/base_layer.py | 26
-rw-r--r--  tensorflow/python/keras/engine/sequential.py | 1
-rw-r--r--  tensorflow/python/keras/engine/training.py | 13
-rw-r--r--  tensorflow/python/keras/engine/training_test.py | 4
-rw-r--r--  tensorflow/python/keras/initializers_test.py | 28
-rw-r--r--  tensorflow/python/keras/integration_test.py | 22
-rw-r--r--  tensorflow/python/keras/layers/advanced_activations_test.py | 18
-rw-r--r--  tensorflow/python/keras/layers/convolutional_recurrent_test.py | 12
-rw-r--r--  tensorflow/python/keras/layers/core_test.py | 20
-rw-r--r--  tensorflow/python/keras/layers/embeddings_test.py | 2
-rw-r--r--  tensorflow/python/keras/layers/local_test.py | 8
-rw-r--r--  tensorflow/python/keras/layers/merge_test.py | 6
-rw-r--r--  tensorflow/python/keras/layers/noise_test.py | 4
-rw-r--r--  tensorflow/python/keras/layers/normalization.py | 14
-rw-r--r--  tensorflow/python/keras/layers/normalization_test.py | 18
-rw-r--r--  tensorflow/python/keras/layers/recurrent.py | 61
-rw-r--r--  tensorflow/python/keras/layers/recurrent_test.py | 58
-rw-r--r--  tensorflow/python/keras/layers/wrappers.py | 26
-rw-r--r--  tensorflow/python/keras/layers/wrappers_test.py | 38
-rw-r--r--  tensorflow/python/keras/losses_test.py | 8
-rw-r--r--  tensorflow/python/keras/metrics.py | 8
-rw-r--r--  tensorflow/python/keras/metrics_test.py | 12
-rw-r--r--  tensorflow/python/keras/models.py | 3
-rw-r--r--  tensorflow/python/keras/models_test.py | 48
-rw-r--r--  tensorflow/python/keras/preprocessing/__init__.py | 2
-rw-r--r--  tensorflow/python/keras/preprocessing/image.py | 492
-rw-r--r--  tensorflow/python/keras/preprocessing/sequence.py | 63
-rw-r--r--  tensorflow/python/keras/regularizers_test.py | 4
-rw-r--r--  tensorflow/python/keras/utils/multi_gpu_utils_test.py | 17
-rw-r--r--  tensorflow/python/kernel_tests/BUILD | 5
-rw-r--r--  tensorflow/python/kernel_tests/array_ops_test.py | 10
-rw-r--r--  tensorflow/python/kernel_tests/check_ops_test.py | 146
-rw-r--r--  tensorflow/python/kernel_tests/cond_v2_test.py | 31
-rw-r--r--  tensorflow/python/kernel_tests/constant_op_eager_test.py | 2
-rw-r--r--  tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 95
-rw-r--r--  tensorflow/python/kernel_tests/ctc_decoder_ops_test.py | 18
-rw-r--r--  tensorflow/python/kernel_tests/distributions/categorical_test.py | 42
-rw-r--r--  tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py | 56
-rw-r--r--  tensorflow/python/kernel_tests/distributions/identity_bijector_test.py | 2
-rw-r--r--  tensorflow/python/kernel_tests/distributions/kullback_leibler_test.py | 2
-rw-r--r--  tensorflow/python/kernel_tests/distributions/multinomial_test.py | 46
-rw-r--r--  tensorflow/python/kernel_tests/embedding_ops_test.py | 20
-rw-r--r--  tensorflow/python/kernel_tests/extract_image_patches_grad_test.py | 20
-rw-r--r--  tensorflow/python/kernel_tests/functional_ops_test.py | 21
-rw-r--r--  tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py | 4
-rw-r--r--  tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py | 16
-rw-r--r--  tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py | 14
-rw-r--r--  tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py | 46
-rw-r--r--  tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py | 2
-rw-r--r--  tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py | 4
-rw-r--r--  tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py | 2
-rw-r--r--  tensorflow/python/kernel_tests/linalg/linear_operator_test.py | 10
-rw-r--r--  tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py | 44
-rw-r--r--  tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py | 12
-rw-r--r--  tensorflow/python/kernel_tests/matrix_logarithm_op_test.py | 30
-rw-r--r--  tensorflow/python/kernel_tests/partitioned_variables_test.py | 80
-rw-r--r--  tensorflow/python/kernel_tests/relu_op_test.py | 30
-rw-r--r--  tensorflow/python/kernel_tests/resource_variable_ops_test.py | 36
-rw-r--r--  tensorflow/python/kernel_tests/rnn_test.py | 17
-rw-r--r--  tensorflow/python/kernel_tests/stack_op_test.py | 30
-rw-r--r--  tensorflow/python/kernel_tests/variable_scope_test.py | 14
-rw-r--r--  tensorflow/python/lib/io/file_io.i | 2
-rw-r--r--  tensorflow/python/ops/array_grad.py | 81
-rw-r--r--  tensorflow/python/ops/array_ops.py | 68
-rw-r--r--  tensorflow/python/ops/check_ops.py | 49
-rw-r--r--  tensorflow/python/ops/collective_ops_test.py | 3
-rw-r--r--  tensorflow/python/ops/cond_v2.py | 2
-rw-r--r--  tensorflow/python/ops/cond_v2_impl.py | 148
-rw-r--r--  tensorflow/python/ops/control_flow_ops.py | 10
-rw-r--r--  tensorflow/python/ops/embedding_ops.py | 7
-rw-r--r--  tensorflow/python/ops/lookup_ops.py | 11
-rw-r--r--  tensorflow/python/ops/math_ops.py | 53
-rw-r--r--  tensorflow/python/ops/nn_grad.py | 22
-rw-r--r--  tensorflow/python/ops/nn_ops.py | 2
-rw-r--r--  tensorflow/python/ops/parsing_ops.py | 178
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py | 28
-rw-r--r--  tensorflow/python/ops/sparse_ops.py | 63
-rw-r--r--  tensorflow/python/ops/sparse_ops_test.py | 32
-rw-r--r--  tensorflow/python/ops/variable_scope.py | 41
-rw-r--r--  tensorflow/python/ops/variables.py | 55
-rwxr-xr-x [-rw-r--r--]  tensorflow/python/pywrap_tfe.i | 15
-rw-r--r--  tensorflow/python/saved_model/utils_impl.py | 2
-rw-r--r--  tensorflow/python/tensorflow.i | 1
-rw-r--r--  tensorflow/python/tools/api/generator/api_init_files.bzl | 1
-rw-r--r--  tensorflow/python/tools/api/generator/api_init_files_v1.bzl | 1
-rw-r--r--  tensorflow/python/tools/freeze_graph.py | 64
-rw-r--r--  tensorflow/python/training/adam.py | 8
-rw-r--r--  tensorflow/python/training/checkpoint_management.py | 18
-rw-r--r--  tensorflow/python/training/checkpoint_management_test.py | 47
-rw-r--r--  tensorflow/python/training/checkpointable/util.py | 72
-rw-r--r--  tensorflow/python/training/checkpointable/util_test.py | 32
-rw-r--r--  tensorflow/python/training/moving_averages.py | 55
-rw-r--r--  tensorflow/python/training/moving_averages_test.py | 21
-rw-r--r--  tensorflow/python/training/optimizer.py | 7
-rw-r--r--  tensorflow/python/training/saver_test.py | 2
-rw-r--r--  tensorflow/python/util/tf_export.py | 13
-rw-r--r--  tensorflow/python/util/util.cc | 43
-rw-r--r--  tensorflow/python/util/util.h | 16
-rw-r--r--  tensorflow/python/util/util.i | 3
162 files changed, 4849 insertions(+), 2146 deletions(-)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index e1d3422730..5af6437c56 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -134,6 +134,7 @@ py_library(
"//tensorflow/core:protos_all_py",
"//tensorflow/python/compat",
"//tensorflow/python/data",
+ "//tensorflow/python/distribute:estimator_training",
"//tensorflow/python/feature_column:feature_column_py",
"//tensorflow/python/keras",
"//tensorflow/python/ops/distributions",
@@ -723,7 +724,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":array_ops",
- ":cond_v2_impl",
":dtypes",
":framework_ops",
":graph_to_function_def",
@@ -1348,6 +1348,19 @@ py_test(
)
py_test(
+ name = "framework_ops_enable_eager_test",
+ size = "small",
+ srcs = ["framework/ops_enable_eager_test.py"],
+ main = "framework/ops_enable_eager_test.py",
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework",
+ ":platform_test",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+py_test(
name = "framework_tensor_shape_test",
size = "small",
srcs = ["framework/tensor_shape_test.py"],
@@ -2620,8 +2633,10 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":constant_op",
+ ":dtypes",
":framework_test_lib",
":sparse_ops",
+ ":sparse_tensor",
],
)
@@ -3245,7 +3260,6 @@ py_library(
),
srcs_version = "PY2AND3",
deps = [
- "saver",
":array_ops",
":array_ops_gen",
":checkpoint_management",
@@ -3269,6 +3283,7 @@ py_library(
":random_ops",
":resource_variable_ops",
":resources",
+ ":saver",
":sdca_ops",
":session",
":sparse_ops",
@@ -3762,6 +3777,7 @@ tf_py_wrap_cc(
"framework/python_op_gen.i",
"grappler/cluster.i",
"grappler/cost_analyzer.i",
+ "grappler/graph_analyzer.i",
"grappler/item.i",
"grappler/model_analyzer.i",
"grappler/tf_optimizer.i",
@@ -3820,6 +3836,7 @@ tf_py_wrap_cc(
"//tensorflow/core/grappler/clusters:single_machine",
"//tensorflow/core/grappler/clusters:virtual_cluster",
"//tensorflow/core/grappler/costs:graph_memory",
+ "//tensorflow/core/grappler/graph_analyzer:graph_analyzer_tool",
"//tensorflow/core/grappler/optimizers:meta_optimizer",
"//tensorflow/core:lib",
"//tensorflow/core:reader_base",
@@ -5521,6 +5538,18 @@ py_test(
],
)
+py_binary(
+ name = "graph_analyzer",
+ srcs = [
+ "grappler/graph_analyzer.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":framework_for_generated_wrappers",
+ ":pywrap_tensorflow_internal",
+ ],
+)
+
pyx_library(
name = "framework_fast_tensor_util",
srcs = ["framework/fast_tensor_util.pyx"],
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index a4dd5467c0..74b001a572 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 22)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 8, 28)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
index e16aa82d4d..159218c99b 100644
--- a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
@@ -110,8 +110,24 @@ class ConcatenateDatasetTest(test.TestCase):
dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices(
to_concatenate_components)
- with self.assertRaisesRegexp(ValueError,
- "don't have the same number of elements"):
+ with self.assertRaisesRegexp(TypeError, "have different types"):
+ input_dataset.concatenate(dataset_to_concatenate)
+
+ def testConcatenateDatasetDifferentKeys(self):
+ input_components = {
+ "foo": np.array([[1], [2], [3], [4]]),
+ "bar": np.array([[12], [13], [14], [15]])
+ }
+ to_concatenate_components = {
+ "foo": np.array([[1], [2], [3], [4]]),
+ "baz": np.array([[5], [6], [7], [8]])
+ }
+
+ input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components)
+ dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices(
+ to_concatenate_components)
+
+ with self.assertRaisesRegexp(TypeError, "have different types"):
input_dataset.concatenate(dataset_to_concatenate)
def testConcatenateDatasetDifferentType(self):
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index fdab8abfae..8c37b1871b 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -1684,15 +1684,14 @@ class ConcatenateDataset(Dataset):
super(ConcatenateDataset, self).__init__()
self._input_dataset = input_dataset
self._dataset_to_concatenate = dataset_to_concatenate
- nest.assert_same_structure(input_dataset.output_types,
- dataset_to_concatenate.output_types)
- for a, b in zip(
- nest.flatten(input_dataset.output_types),
- nest.flatten(dataset_to_concatenate.output_types)):
- if a != b:
- raise TypeError(
- "Two datasets to concatenate have different types %s and %s" %
- (input_dataset.output_types, dataset_to_concatenate.output_types))
+ if input_dataset.output_types != dataset_to_concatenate.output_types:
+ raise TypeError(
+ "Two datasets to concatenate have different types %s and %s" %
+ (input_dataset.output_types, dataset_to_concatenate.output_types))
+ if input_dataset.output_classes != dataset_to_concatenate.output_classes:
+ raise TypeError(
+ "Two datasets to concatenate have different classes %s and %s" %
+ (input_dataset.output_classes, dataset_to_concatenate.output_classes))
def _as_variant_tensor(self):
# pylint: disable=protected-access
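For illustration, a minimal sketch of the stricter validation above (the arrays are placeholders): differing dict keys change `output_types`, so `concatenate` now fails fast with a `TypeError` instead of a nested-structure `ValueError`.
import numpy as np
from tensorflow.python.data.ops import dataset_ops

a = dataset_ops.Dataset.from_tensor_slices({"foo": np.array([[1], [2]])})
b = dataset_ops.Dataset.from_tensor_slices({"baz": np.array([[3], [4]])})
try:
  a.concatenate(b)  # keys differ, so output_types differ
except TypeError as e:
  print(e)  # "Two datasets to concatenate have different types ..."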
diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py
index 1b596bdfc0..9d621fcd30 100644
--- a/tensorflow/python/data/util/nest.py
+++ b/tensorflow/python/data/util/nest.py
@@ -129,35 +129,18 @@ def flatten(nest):
return _pywrap_tensorflow.FlattenForData(nest)
-def _recursive_assert_same_structure(nest1, nest2, check_types):
- is_sequence_nest1 = is_sequence(nest1)
- if is_sequence_nest1 != is_sequence(nest2):
- raise ValueError(
- "The two structures don't have the same nested structure. "
- "First structure: %s, second structure: %s." % (nest1, nest2))
-
- if is_sequence_nest1:
- type_nest1 = type(nest1)
- type_nest2 = type(nest2)
- if check_types and type_nest1 != type_nest2:
- raise TypeError(
- "The two structures don't have the same sequence type. First "
- "structure has type %s, while second structure has type %s."
- % (type_nest1, type_nest2))
-
- for n1, n2 in zip(_yield_value(nest1), _yield_value(nest2)):
- _recursive_assert_same_structure(n1, n2, check_types)
-
-
def assert_same_structure(nest1, nest2, check_types=True):
"""Asserts that two structures are nested in the same way.
Args:
nest1: an arbitrarily nested structure.
nest2: an arbitrarily nested structure.
- check_types: if `True` (default) types of sequences are checked as
- well. If set to `False`, for example a list and a tuple of objects will
- look same if they have the same size.
+ check_types: if `True` (default), the types of sequences must match as
+ well. For dictionaries, the "type" is considered to include the keys; in
+ other words, two dictionaries with different keys are considered to have
+ different "types". If set to `False`, two iterables are considered the
+ same as long as the elements they yield have the same structure.
Raises:
ValueError: If the two structures do not have the same number of elements or
@@ -165,13 +148,7 @@ def assert_same_structure(nest1, nest2, check_types=True):
TypeError: If the two structures differ in the type of sequence in any of
their substructures. Only possible if `check_types` is `True`.
"""
- len_nest1 = len(flatten(nest1)) if is_sequence(nest1) else 1
- len_nest2 = len(flatten(nest2)) if is_sequence(nest2) else 1
- if len_nest1 != len_nest2:
- raise ValueError("The two structures don't have the same number of "
- "elements. First structure: %s, second structure: %s."
- % (nest1, nest2))
- _recursive_assert_same_structure(nest1, nest2, check_types)
+ _pywrap_tensorflow.AssertSameStructureForData(nest1, nest2, check_types)
def _packed_nest_with_indices(structure, flat, index):
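A minimal sketch of the new behavior (now delegated to the C++ `AssertSameStructureForData`): dictionary keys are part of the structure, and a key mismatch is reported as a key-set error unless type checks are disabled.
from tensorflow.python.data.util import nest

nest.assert_same_structure({"a": 1, "b": 2}, {"a": 3, "b": 4})  # passes
try:
  nest.assert_same_structure({"a": 1}, {"b": 1})
except ValueError as e:
  print(e)  # "... don't have the same set of keys ..."
# Key mismatches count as "type" differences, so disabling type checks
# makes the comparison purely structural:
nest.assert_same_structure({"a": 1}, {"b": 1}, check_types=False)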
diff --git a/tensorflow/python/data/util/nest_test.py b/tensorflow/python/data/util/nest_test.py
index ff380815a4..616aa9f551 100644
--- a/tensorflow/python/data/util/nest_test.py
+++ b/tensorflow/python/data/util/nest_test.py
@@ -163,21 +163,30 @@ class NestTest(test.TestCase):
structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
structure_different_num_elements = ("spam", "eggs")
structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
+ structure_dictionary = {"foo": 2, "bar": 4, "baz": {"foo": 5, "bar": 6}}
+ structure_dictionary_diff_nested = {
+ "foo": 2,
+ "bar": 4,
+ "baz": {
+ "foo": 5,
+ "baz": 6
+ }
+ }
nest.assert_same_structure(structure1, structure2)
nest.assert_same_structure("abc", 1.0)
nest.assert_same_structure("abc", np.array([0, 1]))
nest.assert_same_structure("abc", constant_op.constant([0, 1]))
with self.assertRaisesRegexp(ValueError,
- "don't have the same number of elements"):
+ "don't have the same nested structure"):
nest.assert_same_structure(structure1, structure_different_num_elements)
with self.assertRaisesRegexp(ValueError,
- "don't have the same number of elements"):
+ "don't have the same nested structure"):
nest.assert_same_structure((0, 1), np.array([0, 1]))
with self.assertRaisesRegexp(ValueError,
- "don't have the same number of elements"):
+ "don't have the same nested structure"):
nest.assert_same_structure(0, (0, 1))
with self.assertRaisesRegexp(ValueError,
@@ -203,11 +212,23 @@ class NestTest(test.TestCase):
nest.assert_same_structure(((3,), 4), (3, (4,)))
structure1_list = {"a": ((1, 2), 3), "b": 4, "c": (5, 6)}
+ structure2_list = {"a": ((1, 2), 3), "b": 4, "d": (5, 6)}
with self.assertRaisesRegexp(TypeError,
"don't have the same sequence type"):
nest.assert_same_structure(structure1, structure1_list)
nest.assert_same_structure(structure1, structure2, check_types=False)
nest.assert_same_structure(structure1, structure1_list, check_types=False)
+ with self.assertRaisesRegexp(ValueError, "don't have the same set of keys"):
+ nest.assert_same_structure(structure1_list, structure2_list)
+ with self.assertRaisesRegexp(ValueError, "don't have the same set of keys"):
+ nest.assert_same_structure(structure_dictionary,
+ structure_dictionary_diff_nested)
+ nest.assert_same_structure(
+ structure_dictionary,
+ structure_dictionary_diff_nested,
+ check_types=False)
+ nest.assert_same_structure(
+ structure1_list, structure2_list, check_types=False)
def testMapStructure(self):
structure1 = (((1, 2), 3), 4, (5, 6))
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 8a4ac6aaef..55d2709845 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -576,7 +576,6 @@ py_test(
srcs_version = "PY2AND3",
tags = [
"no_windows",
- "nomac",
"oss_serial",
],
deps = [
@@ -1047,7 +1046,6 @@ cuda_py_test(
tags = [
"no_oss", # Incompatible with bazel_pip.
"no_windows",
- "nomac", # TODO(cais): Install of futures and grpcio on all macs.
"notsan",
],
)
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 98ef9bf492..a081c30781 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -9,6 +9,25 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
+ name = "distribute",
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":distribute_config",
+ ":distribute_coordinator",
+ ":distribute_coordinator_context",
+ ],
+)
+
+py_library(
+ name = "distribute_config",
+ srcs = [
+ "distribute_config.py",
+ ],
+ deps = [],
+)
+
+py_library(
name = "distribute_coordinator",
srcs = [
"distribute_coordinator.py",
@@ -25,7 +44,11 @@ py_test(
size = "large",
srcs = ["distribute_coordinator_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "manual",
+ "no_pip",
+ "notap",
+ ],
deps = [
":distribute_coordinator",
"//tensorflow/core:protos_all_py",
@@ -81,3 +104,17 @@ py_test(
"@absl_py//absl/testing:parameterized",
],
)
+
+# Used only by estimator.
+py_library(
+ name = "estimator_training",
+ srcs = [
+ "estimator_training.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":distribute_coordinator",
+ ":distribute_coordinator_context",
+ "//tensorflow/python:training",
+ ],
+)
diff --git a/tensorflow/python/distribute/distribute_config.py b/tensorflow/python/distribute/distribute_config.py
new file mode 100644
index 0000000000..fac35742fe
--- /dev/null
+++ b/tensorflow/python/distribute/distribute_config.py
@@ -0,0 +1,45 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A configure tuple for high-level APIs for running distribution strategies."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+
+class DistributeConfig(
+ collections.namedtuple(
+ 'DistributeConfig',
+ ['train_distribute', 'eval_distribute', 'remote_cluster'])):
+ """A config tuple for distribution strategies.
+
+ Attributes:
+ train_distribute: a `DistributionStrategy` object for training.
+ eval_distribute: an optional `DistributionStrategy` object for
+ evaluation.
+ remote_cluster: a dict, `ClusterDef` or `ClusterSpec` object specifying
+ the cluster configurations. If this is given, the `train_and_evaluate`
+ method will be running as a standalone client which connects to the
+ cluster for training.
+ """
+
+ def __new__(cls,
+ train_distribute=None,
+ eval_distribute=None,
+ remote_cluster=None):
+ return super(DistributeConfig, cls).__new__(cls, train_distribute,
+ eval_distribute, remote_cluster)
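A usage sketch for the new tuple (the strategy values and worker addresses are placeholders):
from tensorflow.python.distribute.distribute_config import DistributeConfig

config = DistributeConfig(
    train_distribute=None,  # placeholder: a DistributionStrategy in real use
    eval_distribute=None,   # optional; the coordinator falls back to
                            # train_distribute when this is None
    remote_cluster={'worker': ['host1:2222', 'host2:2222']})
print(config.remote_cluster)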
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index eb081b65fc..46cdd64a6e 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -22,9 +22,12 @@ import copy
import json
import os
import threading
+import time
from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator_context
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib
@@ -311,7 +314,11 @@ def _run_single_worker(worker_fn,
worker_barrier=None):
"""Runs a single worker by calling `worker_fn` under context."""
strategy = copy.deepcopy(strategy)
- strategy.configure(session_config, cluster_spec, task_type, task_id)
+ # If there is an EVALUATOR task, we run single-machine eval on that task.
+ if task_type == _TaskType.EVALUATOR:
+ strategy.configure(session_config)
+ else:
+ strategy.configure(session_config, cluster_spec, task_type, task_id)
context = _WorkerContext(
strategy,
cluster_spec,
@@ -328,26 +335,48 @@ def _run_std_server(cluster_spec=None,
task_type=None,
task_id=None,
session_config=None,
- rpc_layer=None):
+ rpc_layer=None,
+ environment=None):
"""Runs a standard server."""
- server = server_lib.Server(
- cluster_spec,
- job_name=task_type,
- task_index=task_id,
- config=session_config,
- protocol=rpc_layer)
- server.start()
- return server
-
-def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config,
- rpc_layer):
+ class _FakeServer(object):
+ """A fake server that runs a master session."""
+
+ def start(self):
+ assert cluster_spec
+ target = cluster_spec.task_address(task_type, task_id)
+ if rpc_layer:
+ target = rpc_layer + "://" + target
+ # A tensorflow server starts when a remote session is created.
+ session.Session(target=target, config=session_config)
+
+ def join(self):
+ while True:
+ time.sleep(5)
+
+ if environment == "google":
+ server = _FakeServer()
+ server.start()
+ return server
+ else:
+ server = server_lib.Server(
+ cluster_spec,
+ job_name=task_type,
+ task_index=task_id,
+ config=session_config,
+ protocol=rpc_layer)
+ server.start()
+ return server
+
+
+def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
+ cluster_spec, session_config, rpc_layer):
"""Runs a standalone client for between-graph replication."""
eval_thread = None
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
- args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+ args=(eval_fn, eval_strategy, None, _TaskType.EVALUATOR, 0,
session_config),
kwargs={
"rpc_layer": rpc_layer,
@@ -378,14 +407,14 @@ def _run_between_graph_client(worker_fn, strategy, cluster_spec, session_config,
eval_thread.join()
-def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
- rpc_layer):
+def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
+ cluster_spec, session_config, rpc_layer):
"""Runs a standalone client for in-graph replication."""
eval_thread = None
if _TaskType.EVALUATOR in cluster_spec.jobs:
eval_thread = threading.Thread(
target=_run_single_worker,
- args=(worker_fn, strategy, cluster_spec, _TaskType.EVALUATOR, 0,
+ args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
session_config),
kwargs={
"rpc_layer": rpc_layer,
@@ -408,6 +437,8 @@ def _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
# is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(worker_fn,
strategy,
+ eval_fn=None,
+ eval_strategy=None,
mode=CoordinatorMode.STANDALONE_CLIENT,
cluster_spec=None,
task_type=None,
@@ -488,10 +519,12 @@ def run_distribute_coordinator(worker_fn,
If `cluster_spec` is not given in any format, it becomes local training and
this coordinator will connect to a local session.
- For evaluation, if "evaluator" exist in the cluster_spec, a separate thread
- will be created with its `task_type` set to "evaluator". If "evaluator" is not
- set in the cluster_spec, it entirely depends on the `worker_fn` for how to do
- evaluation.
+ For evaluation, if "evaluator" exists in the cluster_spec, a separate thread
+ will be created to call `eval_fn` with its `task_type` set to "evaluator". If
+ `eval_fn` is not defined, it falls back to `worker_fn`. This implies that
+ evaluation will be done on a single machine if there is an "evaluator" task.
+ If "evaluator" doesn't exist in the cluster_spec, evaluation is entirely up
+ to the `worker_fn`.
Args:
worker_fn: the function to be called. The function should accept a
@@ -501,6 +534,8 @@ def run_distribute_coordinator(worker_fn,
run between-graph replicated training or not, whether to run init ops,
etc. This object will also be configured given `session_config`,
`cluster_spec`, `task_type` and `task_id`.
+ eval_fn: optional function for "evaluator" task.
+ eval_strategy: optional DistributionStrategy object for "evaluator" task.
mode: in which mode this distribute coordinator runs.
cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
in a cluster. If not set or empty, fall back to local training.
@@ -531,32 +566,59 @@ def run_distribute_coordinator(worker_fn,
"`tf.train.ClusterDef` object")
# TODO(yuefengz): validate cluster_spec.
+ rpc_layer = tf_config.get("rpc_layer", rpc_layer)
+ environment = tf_config.get("environment", None)
+
+ if cluster_spec:
+ logging.info(
+ "Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
+ "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r", mode,
+ cluster_spec.as_dict(), task_type, task_id, environment, rpc_layer)
+
if not cluster_spec:
# `mode` is ignored in the local case.
+ logging.info("Running local Distribute Coordinator.")
_run_single_worker(worker_fn, strategy, None, None, None, session_config,
rpc_layer)
+ if eval_fn:
+ _run_single_worker(eval_fn, eval_strategy or strategy, None, None, None,
+ session_config, rpc_layer)
elif mode == CoordinatorMode.STANDALONE_CLIENT:
+ eval_fn = eval_fn or worker_fn
+ eval_strategy = eval_strategy or strategy
+
# The client must know the cluster but servers in the cluster don't have to
# know the client.
if task_type in [_TaskType.CLIENT, None]:
if strategy.between_graph:
- _run_between_graph_client(worker_fn, strategy, cluster_spec,
- session_config, rpc_layer)
+ _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
+ cluster_spec, session_config, rpc_layer)
else:
- _run_in_graph_client(worker_fn, strategy, cluster_spec, session_config,
- rpc_layer)
+ _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
+ cluster_spec, session_config, rpc_layer)
else:
# If not a client job, run the standard server.
server = _run_std_server(
- cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id,
+ rpc_layer=rpc_layer,
+ environment=environment)
server.join()
else:
if mode != CoordinatorMode.INDEPENDENT_WORKER:
raise ValueError("Unexpected coordinator mode: %r" % mode)
+ eval_fn = eval_fn or worker_fn
+ eval_strategy = eval_strategy or strategy
+
# Every one starts a standard server.
server = _run_std_server(
- cluster_spec=cluster_spec, task_type=task_type, task_id=task_id)
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id,
+ rpc_layer=rpc_layer,
+ environment=environment)
if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
if strategy.between_graph:
@@ -572,8 +634,8 @@ def run_distribute_coordinator(worker_fn,
else:
server.join()
elif task_type == _TaskType.EVALUATOR:
- _run_single_worker(worker_fn, strategy, cluster_spec, task_type, task_id,
- session_config, rpc_layer)
+ _run_single_worker(eval_fn, eval_strategy, cluster_spec, task_type,
+ task_id, session_config, rpc_layer)
else:
if task_type != _TaskType.PS:
raise ValueError("Unexpected task_type: %r" % task_type)
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index 97c6bdd15a..5dd57fa134 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -20,8 +20,10 @@ from __future__ import print_function
import contextlib
import copy
+import json
import os
import sys
+import time
import threading
import six
@@ -59,6 +61,8 @@ INDEPENDENT_WORKER = distribute_coordinator.CoordinatorMode.INDEPENDENT_WORKER
NUM_WORKERS = 3
NUM_PS = 2
+original_sys_exit = sys.exit
+
def _bytes_to_str(maybe_bytes):
if isinstance(maybe_bytes, six.string_types):
@@ -369,7 +373,8 @@ class DistributeCoordinatorTestBase(test.TestCase):
cluster_spec=None,
task_type=None,
task_id=None,
- rpc_layer=None):
+ rpc_layer=None,
+ environment=None):
task_type = str(task_type)
task_id = task_id or 0
with self._lock:
@@ -730,6 +735,63 @@ class DistributeCoordinatorTestInpendentWorkerMode(
self.assertTrue(self._std_servers[WORKER][2].joined)
self.assertFalse(self._std_servers[EVALUATOR][0].joined)
+ def testRunStdServerInGoogleEnvironment(self):
+ cluster_spec = {"worker": ["fake_worker"], "ps": ["localhost:0"]}
+ tf_config = {"cluster": cluster_spec, "environment": "google"}
+
+ joined = [False]
+
+ def _fake_sleep(_):
+ joined[0] = True
+ original_sys_exit(0)
+
+ def _thread_fn(cluster_spec):
+ distribute_coordinator.run_distribute_coordinator(
+ None,
+ None,
+ mode=INDEPENDENT_WORKER,
+ cluster_spec=cluster_spec,
+ task_type="ps",
+ task_id=0)
+
+ with test.mock.patch.dict(
+ "os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+ time, "sleep", _fake_sleep):
+ t = threading.Thread(target=_thread_fn, args=(cluster_spec,))
+ t.start()
+ t.join()
+ self.assertTrue(joined[0])
+
+ def testRpcLayerEnvironmentVariable(self):
+ cluster_spec = {"worker": ["fake_worker"], "ps": ["fake_ps"]}
+ tf_config = {"cluster": cluster_spec, "rpc_layer": "cake"}
+
+ rpc_layer_from_coordinator = [None]
+
+ def _run_mock_server(cluster_spec=None,
+ task_type=None,
+ task_id=None,
+ session_config=None,
+ rpc_layer=None,
+ environment=None):
+ del cluster_spec, task_type, task_id, session_config, environment
+ rpc_layer_from_coordinator[0] = rpc_layer
+ return MockServer()
+
+ with test.mock.patch.dict(
+ "os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}), test.mock.patch.object(
+ distribute_coordinator, "_run_std_server", _run_mock_server):
+ distribute_coordinator.run_distribute_coordinator(
+ None,
+ None,
+ mode=INDEPENDENT_WORKER,
+ cluster_spec=cluster_spec,
+ task_type="ps",
+ task_id=0)
+ self.assertEqual(rpc_layer_from_coordinator[0], "cake")
+
if __name__ == "__main__":
# TODO(yuefengz): find a smart way to terminate std server threads.
diff --git a/tensorflow/python/distribute/estimator_training.py b/tensorflow/python/distribute/estimator_training.py
new file mode 100644
index 0000000000..202e19c420
--- /dev/null
+++ b/tensorflow/python/distribute/estimator_training.py
@@ -0,0 +1,264 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Training utilities for Estimator to use Distribute Coordinator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+import six
+
+from tensorflow.python.distribute import distribute_coordinator as dc
+from tensorflow.python.distribute import distribute_coordinator_context as dc_context
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import server_lib
+
+# pylint: disable=protected-access
+CHIEF = dc._TaskType.CHIEF
+EVALUATOR = dc._TaskType.EVALUATOR
+PS = dc._TaskType.PS
+WORKER = dc._TaskType.WORKER
+
+# pylint: enable=protected-access
+
+
+def _count_ps(cluster_spec):
+ """Counts the number of parameter servers in cluster_spec."""
+ if not cluster_spec:
+ raise RuntimeError(
+ 'Internal error: `_count_ps` does not expect empty cluster_spec.')
+
+ return len(cluster_spec.as_dict().get(PS, []))
+
+
+def _count_worker(cluster_spec, chief_task_type):
+ """Counts the number of workers (including chief) in cluster_spec."""
+ if not cluster_spec:
+ raise RuntimeError(
+ 'Internal error: `_count_worker` does not expect empty cluster_spec.')
+
+ return (len(cluster_spec.as_dict().get(WORKER, [])) + len(
+ cluster_spec.as_dict().get(chief_task_type, [])))
+
+
+def _get_global_id(cluster_spec, task_type, task_id, chief_task_type):
+ """Returns the global id of the given task type in a cluster."""
+ if not task_type:
+ return 0
+
+ # Sort task names in cluster by "chief"/"master", "evaluator", "worker"
+ # and "ps". More details can be found at the documentation of
+ # @{tf.estimator.RunConfig.global_id_in_cluster}.
+ task_type_ordered_list = []
+ if chief_task_type in cluster_spec.jobs:
+ task_type_ordered_list = [chief_task_type]
+ task_type_ordered_list.extend([
+ t for t in sorted(cluster_spec.jobs) if t != chief_task_type and t != PS
+ ])
+ if PS in cluster_spec.jobs:
+ task_type_ordered_list.append(PS)
+
+ # Find the right global_id for the current task.
+ next_global_id = 0
+ for t in task_type_ordered_list:
+ if t == task_type:
+ return next_global_id + task_id
+ # `cluster_spec.job_tasks` returns all task addresses of type `t`.
+ next_global_id += len(cluster_spec.job_tasks(t))
+
+ # It is unexpected that it passes through all task_types in
+ # `task_type_ordered_list`.
+ raise RuntimeError('Internal Error: `task_type` ({}) is not in '
+ 'cluster_spec ({}).'.format(task_type, cluster_spec))
+
+
+def _init_run_config_from_worker_context(config, worker_context):
+ """Initializes run config from distribute coordinator's worker context."""
+
+ # pylint: disable=protected-access
+ config._service = None
+ config._cluster_spec = worker_context.cluster_spec
+ config._task_type = worker_context.task_type
+ config._task_id = worker_context.task_id
+ config._evaluation_master = worker_context.master_target
+ config._master = worker_context.master_target
+ config._is_chief = worker_context.is_chief
+
+ if config._cluster_spec:
+ # Distributed mode.
+ if config._task_type != EVALUATOR:
+
+ config._num_ps_replicas = _count_ps(config._cluster_spec)
+ config._num_worker_replicas = _count_worker(
+ config._cluster_spec, chief_task_type=CHIEF)
+ config._global_id_in_cluster = _get_global_id(
+ config._cluster_spec,
+ config._task_type,
+ config._task_id,
+ chief_task_type=CHIEF)
+ else:
+ # Evaluator task should not be aware of the other tasks.
+ config._cluster_spec = server_lib.ClusterSpec({})
+ config._num_ps_replicas = 0
+ config._num_worker_replicas = 0
+ config._global_id_in_cluster = None # undefined
+ else:
+ # Local mode.
+ config._global_id_in_cluster = 0
+ config._num_ps_replicas = 0
+ config._num_worker_replicas = 1
+
+
+def init_run_config(config, tf_config):
+ """Initializes RunConfig for distribution strategies."""
+ # pylint: disable=protected-access
+ if (config._experimental_distribute and
+ config._experimental_distribute.train_distribute):
+ if config._train_distribute:
+ raise ValueError('Either `train_distribute` or '
+ '`experimental_distribute.train_distribute` can be set.')
+ config._train_distribute = config._experimental_distribute.train_distribute
+
+ if (config._experimental_distribute and
+ config._experimental_distribute.eval_distribute):
+ if config._eval_distribute:
+ raise ValueError('Either `eval_distribute` or '
+ '`experimental_distribute.eval_distribute` can be set.')
+ config._eval_distribute = config._experimental_distribute.eval_distribute
+
+ cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {}))
+ config._init_distributed_setting_from_environment_var({})
+
+ # Use distribute coordinator with STANDALONE_CLIENT mode if
+ # `experimental_distribute.remote_cluster` is set.
+ if (config._train_distribute and config._experimental_distribute and
+ config._experimental_distribute.remote_cluster):
+ if tf_config:
+ raise ValueError('Cannot set both TF_CONFIG environment variable and '
+ '`experimental_distribute.remote_cluster`')
+ config._distribute_coordinator_mode = dc.CoordinatorMode.STANDALONE_CLIENT
+ config._cluster_spec = config._experimental_distribute.remote_cluster
+ logging.info('RunConfig initialized for Distribute Coordinator with '
+ 'STANDALONE_CLIENT mode')
+ return
+
+ # Don't use distribute coordinator if it is local training, or the cluster
+ # has a MASTER job, or `train_distribute` is not specified.
+ if (not tf_config or 'master' in cluster_spec.jobs or
+ not config._train_distribute):
+ config._distribute_coordinator_mode = None
+ config._init_distributed_setting_from_environment_var(tf_config)
+ config._maybe_overwrite_session_config_for_distributed_training()
+ logging.info('Not using Distribute Coordinator.')
+ return
+
+ # Use distribute coordinator with INDEPENDENT_WORKER mode otherwise.
+ assert tf_config
+
+ # Set the cluster_spec only since the distributed setting will come from
+ # distribute coordinator.
+ config._cluster_spec = cluster_spec
+ config._distribute_coordinator_mode = dc.CoordinatorMode.INDEPENDENT_WORKER
+ logging.info('RunConfig initialized for Distribute Coordinator with '
+ 'INDEPENDENT_WORKER mode')
+
+
+def should_run_distribute_coordinator(config):
+ """Checks the config to see whether to run distribute coordinator."""
+ # pylint: disable=protected-access
+ if (not hasattr(config, '_distribute_coordinator_mode') or
+ config._distribute_coordinator_mode is None):
+ return False
+ if (not isinstance(config._distribute_coordinator_mode, six.string_types) or
+ config._distribute_coordinator_mode not in [
+ dc.CoordinatorMode.STANDALONE_CLIENT,
+ dc.CoordinatorMode.INDEPENDENT_WORKER
+ ]):
+ logging.warning('Unexpected distribute_coordinator_mode: %r',
+ config._distribute_coordinator_mode)
+ return False
+ if not config.cluster_spec:
+ logging.warning('Running `train_and_evaluate` locally, ignoring '
+ '`experimental_distribute_coordinator_mode`.')
+ return False
+ return True
+
+
+def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
+ """Run distribute coordinator for Estimator's `train_and_evaluate`.
+
+ Args:
+ estimator: An `Estimator` instance to train and evaluate.
+ train_spec: A `TrainSpec` instance to specify the training specification.
+ eval_spec: A `EvalSpec` instance to specify the evaluation and export
+ specification.
+ executor_cls: the evaluation executor class of Estimator.
+
+ Raises:
+ ValueError: if `distribute_coordinator_mode` is None in RunConfig.
+ """
+ run_config = estimator.config
+ if not run_config._distribute_coordinator_mode: # pylint: disable=protected-access
+ raise ValueError(
+ 'Distribute coordinator mode is not specified in `RunConfig`.')
+
+ def _worker_fn(strategy):
+ """Function for worker task."""
+ local_estimator = copy.deepcopy(estimator)
+ # pylint: disable=protected-access
+ local_estimator._config._train_distribute = strategy
+ _init_run_config_from_worker_context(
+ local_estimator._config, dc_context.get_current_worker_context())
+ local_estimator._train_distribution = strategy
+ # pylint: enable=protected-access
+
+ local_estimator.train(
+ input_fn=train_spec.input_fn,
+ max_steps=train_spec.max_steps,
+ hooks=list(train_spec.hooks))
+
+ def _eval_fn(strategy):
+ """Function for evaluator task."""
+ local_estimator = copy.deepcopy(estimator)
+ # pylint: disable=protected-access
+ local_estimator._config._eval_distribute = strategy
+ _init_run_config_from_worker_context(
+ local_estimator._config, dc_context.get_current_worker_context())
+ local_estimator._eval_distribution = strategy
+
+ executor = executor_cls(local_estimator, train_spec, eval_spec)
+ executor._start_continuous_evaluation()
+ # pylint: enable=protected-access
+
+ # pylint: disable=protected-access
+ if (run_config._distribute_coordinator_mode ==
+ dc.CoordinatorMode.STANDALONE_CLIENT):
+ cluster_spec = run_config.cluster_spec
+ assert cluster_spec
+ else:
+ # The cluster_spec comes from TF_CONFIG environment variable if it is
+ # INDEPENDENT_WORKER mode.
+ cluster_spec = None
+
+ dc.run_distribute_coordinator(
+ _worker_fn,
+ run_config.train_distribute,
+ _eval_fn,
+ run_config.eval_distribute,
+ mode=run_config._distribute_coordinator_mode,
+ cluster_spec=cluster_spec,
+ session_config=run_config.session_config)
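A worked example of the `_get_global_id` ordering above (chief first, then the remaining non-ps task types alphabetically, then ps), assuming the module-level helpers defined in this file:
from tensorflow.python.training import server_lib

spec = server_lib.ClusterSpec({
    'chief': ['host0:2222'],
    'worker': ['host1:2222', 'host2:2222'],
    'ps': ['host3:2222'],
})
# Global ids: chief:0 -> 0, worker:0 -> 1, worker:1 -> 2, ps:0 -> 3.
assert _get_global_id(spec, 'worker', 1, chief_task_type='chief') == 2
assert _get_global_id(spec, 'ps', 0, chief_task_type='chief') == 3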
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index bdabbf4ea3..6f48d38b58 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -237,6 +237,7 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":graph_only_ops",
+ "//tensorflow/python:cond_v2_impl",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 6a327bd010..13fb0e88a6 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -504,9 +504,7 @@ class Context(object):
Args:
fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
"""
- pywrap_tensorflow.TFE_ContextAddFunction(
- self._handle, # pylint: disable=protected-access
- fn)
+ pywrap_tensorflow.TFE_ContextAddFunction(self._handle, fn)
def add_function_def(self, fdef):
"""Add a function definition to the context.
@@ -519,9 +517,7 @@ class Context(object):
"""
fdef_string = fdef.SerializeToString()
pywrap_tensorflow.TFE_ContextAddFunctionDef(
- self._handle, # pylint: disable=protected-access
- fdef_string,
- len(fdef_string))
+ self._handle, fdef_string, len(fdef_string))
def add_post_execution_callback(self, callback):
"""Add a post-execution callback to the context.
@@ -633,14 +629,7 @@ def context():
def context_safe():
- return _context
-
-
-# TODO(agarwal): remove this.
-def get_default_context():
- """Same as context."""
- if _context is None:
- _initialize_context()
+ """Returns current context (or None if one hasn't been initialized)."""
return _context
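With `get_default_context` removed, the two remaining accessors divide the work as sketched here: `context()` creates the context on first use, while `context_safe()` never triggers initialization.
from tensorflow.python.eager import context

ctx = context.context()         # lazily initializes and returns the Context
maybe = context.context_safe()  # returns None if nothing initialized it yet
assert maybe is ctx             # already initialized by the call above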
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index cc765725a4..cbd6f4cb75 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+import pickle
import threading
import numpy as np
@@ -185,6 +187,17 @@ class TFETest(test_util.TensorFlowTestCase):
device_count={'GPU': 0}))
self.assertEquals(0, ctx.num_gpus())
+ def testPickle(self):
+ tmp_dir = self.get_temp_dir()
+ fname = os.path.join(tmp_dir, 't.pickle')
+ with open(fname, 'wb') as f:
+ t = constant_op.constant(10.0)
+ pickle.dump(t, f)
+
+ with open(fname, 'rb') as f:
+ t = pickle.load(f)
+ self.assertAllEqual(t.numpy(), 10.0)
+
def testTensorPlacement(self):
if not context.context().num_gpus():
self.skipTest('No GPUs found')
diff --git a/tensorflow/python/eager/execution_callbacks.py b/tensorflow/python/eager/execution_callbacks.py
index 9a08259653..80ff4459d6 100644
--- a/tensorflow/python/eager/execution_callbacks.py
+++ b/tensorflow/python/eager/execution_callbacks.py
@@ -146,7 +146,7 @@ def inf_nan_callback(op_type,
"""
del attrs, inputs # Not used.
- ctx = context.get_default_context()
+ ctx = context.context()
for index, output in enumerate(outputs):
if not output.dtype.is_numpy_compatible:
@@ -263,12 +263,12 @@ def add_execution_callback(callback):
Return value(s) from the callback are ignored.
"""
execute.execute = execute.execute_with_callbacks
- context.get_default_context().add_post_execution_callback(callback)
+ context.context().add_post_execution_callback(callback)
def clear_execution_callbacks():
"""Clear all execution callbacks from the default eager context."""
- context.get_default_context().clear_post_execution_callbacks()
+ context.context().clear_post_execution_callbacks()
def seterr(inf_or_nan=None):
@@ -309,7 +309,7 @@ def seterr(inf_or_nan=None):
"Valid actions are %s." % (inf_or_nan, _VALID_CALLBACK_ACTIONS))
old_settings = {"inf_or_nan": "ignore"}
- default_context = context.get_default_context()
+ default_context = context.context()
carryover_callbacks = []
for callback in default_context.post_execution_callbacks:
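A minimal registration sketch tying these calls together, assuming the five-argument callback signature documented for `add_execution_callback`:
from tensorflow.python.eager import execution_callbacks

def _log_callback(op_type, op_name, attrs, inputs, outputs):
  del op_name, attrs, inputs, outputs  # unused in this sketch
  print('executed:', op_type)

execution_callbacks.add_execution_callback(_log_callback)
# ... execute eager ops; each one invokes _log_callback ...
execution_callbacks.clear_execution_callbacks()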
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index e04595f5ed..6c87dccaf1 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import functools
+import sys
import threading
import numpy as np
@@ -33,10 +34,12 @@ from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
from tensorflow.python.framework import c_api_util
+from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
@@ -48,6 +51,10 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
+# This is to avoid a circular dependency with cond_v2_impl
+# (function -> gradients_impl -> control_flow_ops -> cond_v2_impl).
+cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-access
+
def create_substitute_placeholder(value, name, dtype=None):
"""Creates a placeholder for `value` and propagates shape info to it."""
@@ -112,10 +119,6 @@ class CapturingGraph(ops.Graph):
# for resource tensors.
self._last_op_using_resource_tensor = {}
- # TODO(apassos) remove once the C API is used by default.
- def _use_c_api_hack(self):
- return True
-
def clear_resource_control_flow_state(self):
self._last_op_using_resource_tensor = {}
@@ -179,6 +182,14 @@ class CapturingGraph(ops.Graph):
compute_device=compute_device)
+def _get_device_functions(ctx, graph):
+ """Returns a tuple of device functions representing the device stack."""
+ if ctx.executing_eagerly():
+ return (pydev.merge_device(ctx.device_name),)
+ else:
+ return tuple(graph._device_functions_outer_to_inner) # pylint: disable=protected-access
+
+
class FuncGraph(CapturingGraph):
"""Graph representing a function body.
@@ -194,14 +205,16 @@ class FuncGraph(CapturingGraph):
by this function. The Tensors in this structure are the same as those of
self.outputs. Note that this structure might contain Python `None`s.
variables: Variables that should be watched during function execution.
+ outer_graph: The graph this function is defined in. May be another FuncGraph
+ or the global default Graph.
seed: The graph-level random seed.
"""
def __init__(self, name):
"""Construct a new FuncGraph.
- The graph will inherit its graph key, collections, seed, and distribution
- strategy stack from the current context or graph.
+ The graph will inherit its graph key, collections, seed, device stack, and
+ distribution strategy stack from the current context or graph.
Args:
name: the name of the function.
@@ -213,19 +226,20 @@ class FuncGraph(CapturingGraph):
self.outputs = []
self.structured_outputs = None
self.variables = []
+ self.outer_graph = ops.get_default_graph()
+
+ graph = self.outer_graph
if context.executing_eagerly():
self.seed = context.global_seed()
self._xla_compile = (context.context().device_spec.device_type == "TPU")
+ self._add_device_to_stack(context.context().device_name)
else:
- graph = ops.get_default_graph()
- # Inherit the graph key, since this is used for matching variables in
- # optimizers.
- self._graph_key = graph._graph_key # pylint: disable=protected-access
self.seed = graph.seed
self._xla_compile = getattr(graph, "_xla_compile", False)
+ self._device_function_stack = graph._device_function_stack.copy() # pylint: disable=protected-access
+ self._colocation_stack = graph._colocation_stack.copy() # pylint: disable=protected-access
- graph = ops.get_default_graph()
# TODO(b/112165328, b/112906995): summaries depend on inheriting collections
# from the default graph even in eager mode. It'd be nice to not have a
# default graph with eager execution, so hopefully this will go away when we
@@ -236,6 +250,9 @@ class FuncGraph(CapturingGraph):
# from the default graph even in eager mode. Maybe it should be part of the
# eager context?
self._distribution_strategy_stack = graph._distribution_strategy_stack
+ # Inherit the graph key, since this is used for matching variables in
+ # optimizers.
+ self._graph_key = graph._graph_key
# pylint: enable=protected-access
def capture(self, tensor, name=None):
@@ -248,6 +265,16 @@ class FuncGraph(CapturingGraph):
return internal_tensor
+ @property
+ def external_captures(self):
+ """External tensors captured by this function."""
+ return list(self.captures.keys())
+
+ @property
+ def internal_captures(self):
+ """Placeholders in this function corresponding captured tensors."""
+ return list(self.captures.values())
+
def _forward_name(n):
"""The name of a generated forward defun named n."""
@@ -423,15 +450,15 @@ def _flatten(sequence):
return outputs
-class GraphCallable(object):
+class Function(object):
"""Callable object encapsulating a function definition and its gradient.
- `GraphCallable` is a callable that encapsulates a function definition and
+ `Function` is a callable that encapsulates a function definition and
is differentiable under `tf.GradientTape` objects.
"""
def __init__(self, func_graph, attrs=None):
- """Initialize a GraphCallable.
+ """Initialize a Function.
Args:
func_graph: An instance of FuncGraph: the function body to wrap.
@@ -449,11 +476,13 @@ class GraphCallable(object):
self._output_shapes = tuple(
output.shape for output in self._func_graph.outputs)
self._attrs = attrs or {}
+ self._device_functions = tuple(
+ self._func_graph._device_functions_outer_to_inner) # pylint: disable=protected-access
self._inference_function = _EagerDefinedFunction(
_inference_name(self._func_graph.name), self._func_graph,
self._func_graph.inputs, self._func_graph.outputs, self._attrs)
- self._backward_graph_callable = None
+ self._backward_graph_function = None
# Map holding distributed variables, keyed by resource handle tensors.
self._distributed_variables = {}
@@ -466,14 +495,94 @@ class GraphCallable(object):
for component_variable in component_variables:
self._distributed_variables[component_variable.handle] = variable
+ def __call__(self, *args):
+ """Executes the wrapped function."""
+ ctx = context.context()
+ device_functions = _get_device_functions(ctx, ops.get_default_graph())
+ if device_functions != self._device_functions:
+ raise ValueError(
+ "The current device stack does not match the device stack under "
+ "which the TensorFlow function '%s' was created.\n"
+ "Current device stack: %s\n%s device stack: %s" %
+ (self._inference_function.name, device_functions,
+ self._inference_function.name, self._device_functions))
+
+ for v in self._func_graph.variables:
+ if v.trainable:
+ tape.watch_variable(v)
+
+ captures = self._resolve_captured_inputs()
+ tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
+ args = tensor_inputs + captures
+
+ if tape.should_record(tensor_inputs) or tape.should_record(captures):
+ return self._backprop_call(args)
+
+ outputs = self._inference_function.call(ctx, args)
+ return self._build_call_outputs(outputs)
+
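The new `__call__` rejects invocation under a device stack other than the one active at trace time. A minimal sketch of the failure mode, assuming the contrib-era API and a config with at least two CPU devices:

```python
# Editor's sketch; assumes the contrib-era API and >= 2 CPU devices.
import tensorflow as tf

tf.enable_eager_execution(config=tf.ConfigProto(device_count={'CPU': 2}))

defined = tf.contrib.eager.defun(lambda: tf.constant(0))

with tf.device('cpu:0'):
  cpu_fn = defined.get_concrete_function()
  print(cpu_fn())  # same device stack as at trace time: OK

with tf.device('cpu:1'):
  try:
    cpu_fn()  # device stack differs from trace time
  except ValueError as e:
    print(e)  # "The current device stack does not match ..."
```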
@property
def graph(self):
+ """Returns the graph from which this function was constructed."""
return self._func_graph
@property
def variables(self):
+ """Returns all variables touched by this function."""
return self._func_graph.variables
+ @property
+ def inputs(self):
+ """Returns tensors in `self.graph` corresponding to arguments."""
+ return self._func_graph.inputs
+
+ @property
+ def outputs(self):
+ """Returns tensors in `self.graph` corresponding to return values."""
+ return self._func_graph.outputs
+
+ @property
+ def captured_inputs(self):
+ """Returns external Tensors captured by this function.
+
+ self.__call__(*args) passes `args + self.captured_inputs` to the function.
+ """
+ return self._captured_inputs
+
+ @property
+ def function_def(self):
+ """Returns a `FunctionDef` object representing this function."""
+ return self._inference_function.definition
+
+ @property
+ def output_shapes(self):
+ """The function's output shapes."""
+ # TODO(ebrevdo): Should we only keep the output shapes associated
+ # with len(self._python_returns) outputs?
+ # TODO(akshayka): Consider removing this.
+ outputs_list = nest.flatten(self._func_graph.structured_outputs)
+ j = 0
+ for i, o in enumerate(outputs_list):
+ if o is not None:
+ if isinstance(o, ops.IndexedSlices):
+ # Extract the shape of the `IndexedSlices` object's `values` field.
+ outputs_list[i] = self._output_shapes[j] # the `values` shape
+ if o.dense_shape is not None:
+ j += 3 # skip over shapes for `values`, `indices`, `dense_shape`
+ else:
+ j += 2 # skip over shapes for `values`, `indices`
+ else:
+ outputs_list[i] = self._output_shapes[j]
+ j += 1
+ return nest.pack_sequence_as(self._func_graph.structured_outputs,
+ outputs_list)
+
+ @property
+ def output_dtypes(self):
+ # TODO(akshayka): Consider removing this.
+ return nest.map_structure(lambda x: x.dtype if x is not None else None,
+ self._func_graph.structured_outputs)
+
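These metadata properties now live on `Function` itself, so a concrete function obtained from `get_concrete_function` can be inspected before it is ever called. A minimal sketch, assuming the contrib-era API:

```python
# Editor's sketch, assuming the contrib-era API.
import tensorflow as tf

tf.enable_eager_execution()


@tf.contrib.eager.defun
def sq(a):
  return tf.matmul(a, a)


t = tf.constant([[1.0, 2.0], [3.0, 4.0]])
sq_op = sq.get_concrete_function(t)
print(sq_op.output_shapes)  # (2, 2)
print(sq_op.output_dtypes)  # <dtype: 'float32'>
```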
def _construct_backprop_function(self):
"""Constructs the backprop function object for this function."""
backwards_graph = FuncGraph(_backward_name(self._func_graph.name))
@@ -494,7 +603,7 @@ class GraphCallable(object):
self._attrs)
# The ordering of `backwards_graph.inputs` is important: inputs of
- # `self._backward_graph_callable` correspond to outputs of
+ # `self._backward_graph_function` correspond to outputs of
# `self._forward_function`.
backwards_graph.inputs = gradients_wrt_outputs + list(
backwards_graph.captures.values())
@@ -503,7 +612,7 @@ class GraphCallable(object):
backwards_graph.outputs.extend(
grad for grad in _flatten(gradients_wrt_inputs) if grad is not None)
backwards_graph.structured_outputs = gradients_wrt_inputs
- self._backward_graph_callable = GraphCallable(
+ self._backward_graph_function = Function(
backwards_graph, attrs=self._attrs)
def _backprop_call(self, args):
@@ -517,7 +626,7 @@ class GraphCallable(object):
Returns:
The call output.
"""
- if self._backward_graph_callable is None:
+ if self._backward_graph_function is None:
self._construct_backprop_function()
ctx = context.context()
@@ -532,49 +641,12 @@ class GraphCallable(object):
side_outputs = outputs[self._num_outputs:]
def backward_function(*args):
- return self._backward_graph_callable(*(list(args) + side_outputs)) # pylint: disable=not-callable
+ return self._backward_graph_function(*(list(args) + side_outputs)) # pylint: disable=not-callable
tape.record_operation(self._forward_function.signature.name, real_outputs,
args, backward_function)
return self._build_call_outputs(real_outputs)
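`_backprop_call` is what makes the class docstring's claim concrete: recording the forward call on the tape with a `backward_function` makes the result differentiable. A minimal sketch, assuming the contrib-era API:

```python
# Editor's sketch, assuming the contrib-era API.
import tensorflow as tf

tf.enable_eager_execution()


@tf.contrib.eager.defun
def f(x):
  return x * x


x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = f(x)
print(tape.gradient(y, x))  # 6.0
```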
- @property
- def output_shapes(self):
- """The function's output shapes."""
- # TODO(ebrevdo): Should we only keep the output shapes associated
- # with len(self._python_returns) outputs?
- outputs_list = nest.flatten(self._func_graph.structured_outputs)
- j = 0
- for i, o in enumerate(outputs_list):
- if o is not None:
- if isinstance(o, ops.IndexedSlices):
- # Extract the shape of the `IndexedSlices` object's `values` field.
- outputs_list[i] = self._output_shapes[j] # the `values` shape
- if o.dense_shape is not None:
- j += 3 # skip over shapes for `values`, `indices`, `dense_shape`
- else:
- j += 2 # skip over shapes for `values`, `indices`
- else:
- outputs_list[i] = self._output_shapes[j]
- j += 1
- return nest.pack_sequence_as(self._func_graph.structured_outputs,
- outputs_list)
-
- @property
- def output_dtypes(self):
- return nest.map_structure(lambda x: x.dtype if x is not None else None,
- self._func_graph.structured_outputs)
-
- @property
- def captured_inputs(self):
- # TODO(akshayka): Should this return `_resolve_captured_inputs()`?
- return self._captured_inputs
-
- @property
- def name(self):
- """Returns the name of the function in Eager-compatible format."""
- return self._inference_function.name.encode("utf-8")
-
def _resolve_captured_inputs(self):
"""Resolve captured distributed variables to their current values.
@@ -601,23 +673,6 @@ class GraphCallable(object):
return resolved_captured_inputs
return self._captured_inputs
- def __call__(self, *args):
- """Executes the passed function in eager mode."""
- for v in self._func_graph.variables:
- if v.trainable:
- tape.watch_variable(v)
-
- captures = self._resolve_captured_inputs()
- tensor_inputs = [x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
- args = tensor_inputs + captures
-
- if tape.should_record(tensor_inputs) or tape.should_record(captures):
- return self._backprop_call(args)
-
- ctx = context.context()
- outputs = self._inference_function.call(ctx, args)
- return self._build_call_outputs(outputs)
-
def _build_call_outputs(self, result):
"""Maps the fdef output list to actual output structure.
@@ -673,7 +728,7 @@ def _get_defun_inputs_from_args(args):
return nest.pack_sequence_as(args, function_inputs)
-def _func_graph_from_py_func(name, python_func, args, kwds, signature=None):
+def func_graph_from_py_func(name, python_func, args, kwds, signature=None):
"""Returns a `FuncGraph` generated from `python_func`.
Args:
@@ -831,13 +886,13 @@ def _deterministic_dict_values(dictionary):
return tuple(dictionary[key] for key in sorted(dictionary))
-class _PolymorphicFunction(object):
+class PolymorphicFunction(object):
"""Wrapper class for the graph functions defined for a Python function.
See the documentation for `defun` for more information on the semantics of
defined functions.
- _PolymorphicFunction class is thread-compatible meaning that minimal
+ PolymorphicFunction class is thread-compatible, meaning that minimal
usage of defuns (defining and calling) is thread-safe, but if users call other
methods or invoke the base `python_function` themselves, external
synchronization is necessary.
@@ -859,9 +914,6 @@ class _PolymorphicFunction(object):
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
argspec has keyword arguments.
- TypeError: if `input_signature` contains anything other than
- `TensorSpec` objects, or (if not None) is anything other than a tuple or
- list.
"""
if isinstance(python_function, functools.partial):
@@ -873,7 +925,7 @@ class _PolymorphicFunction(object):
self._args_to_prepend = tuple()
self._kwds_to_include = {}
self._name = name
- self._arguments_to_functions = {}
+ self._function_cache = collections.OrderedDict()
self._variables = []
self._lock = threading.Lock()
@@ -908,15 +960,40 @@ class _PolymorphicFunction(object):
self._input_signature = tuple(input_signature)
self._flat_input_signature = tuple(nest.flatten(input_signature))
- if any(not isinstance(arg, tensor_spec.TensorSpec)
- for arg in self._flat_input_signature):
- raise TypeError("Invalid input_signature %s; input_signature must be "
- "a possibly nested sequence of TensorSpec objects.")
+
+ def __call__(self, *args, **kwds):
+ """Calls a graph function specialized to the inputs."""
+ graph_function, inputs = self._maybe_define_function(*args, **kwds)
+ return graph_function(*inputs)
+
+ @property
+ def python_function(self):
+ """Returns the wrapped Python function."""
+ return self._python_function
+
+ # TODO(akshayka): Remove this property.
+ @property
+ def variables(self):
+ """Returns the union of all variables referenced by cached `Function`s`."""
+ return self._variables
+
+ def get_concrete_function(self, *args, **kwargs):
+ """Returns a `Function` object specialized to inputs and execution context.
+
+ `args` and `kwargs` are ignored if this `PolymorphicFunction` was created
+ with an `input_signature`.
+
+ Args:
+ *args: inputs to specialize on.
+ **kwargs: inputs to specialize on.
+ """
+ graph_function, _ = self._maybe_define_function(*args, **kwargs)
+ return graph_function
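When an `input_signature` was supplied, only one specialization exists, so the `*args`/`**kwargs` passed here are irrelevant. A minimal sketch, assuming the contrib-era API (`TensorSpec` under `tf.contrib.eager` is an assumption based on the `defun` docstring):

```python
# Editor's sketch, assuming the contrib-era API.
import tensorflow as tf

tf.enable_eager_execution()

defined = tf.contrib.eager.defun(
    lambda a: a * 2.0,
    input_signature=[tf.contrib.eager.TensorSpec([None], tf.float32)])

concrete = defined.get_concrete_function()  # no args needed with a signature
print(concrete(tf.constant([1.0, 2.0])))    # [2. 4.]
```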
def __get__(self, instance, owner):
"""Makes it possible to defun instance methods."""
del owner
- # `instance` here is the instance that this `_PolymorphicFunction` was
+ # `instance` here is the instance that this `PolymorphicFunction` was
# accessed through; e.g., for
#
# class Foo(object):
@@ -926,29 +1003,42 @@ class _PolymorphicFunction(object):
# ...
#
# foo = Foo()
- # foo.bar() # `foo.bar` is a `_PolymorphicFunction` instance
+ # foo.bar() # `foo.bar` is a `PolymorphicFunction` instance
#
# then `instance` will be `foo` (and `owner` will be `Foo`).
return functools.partial(self.__call__, instance)
- def _cache_key(self, args, kwds):
- """Computes the cache key given inputs."""
+ def _cache_key(self, args, kwds, ctx, graph):
+ """Computes the cache key given inputs and execution context."""
if self._input_signature is None:
inputs = (args, kwds) if kwds else args
cache_key = tuple(_encode_arg(arg) for arg in inputs)
else:
del args, kwds
cache_key = self._flat_input_signature
+
# The graph, or whether we're executing eagerly, should be a part of the
# cache key so we don't improperly capture tensors such as variables.
- return cache_key + (context.executing_eagerly() or ops.get_default_graph(),)
+ executing_eagerly = ctx.executing_eagerly()
+ execution_context = executing_eagerly or graph
+
+ # Putting the device in the cache key ensures that call-site device
+ # annotations are respected.
+ device_functions = _get_device_functions(ctx, graph)
+
+ # `ops.colocate_with` directives translate into `ops.device` directives when
+ # eager execution is enabled.
+ colocation_stack = (None if executing_eagerly else
+ tuple(graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
+
+ return cache_key + (execution_context, device_functions, colocation_stack)
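Because the device functions are folded into the key, the same Python callable can now back several graph functions, one per call-site device stack. A minimal sketch, assuming the contrib-era API (`_function_cache` is private and inspected here only for illustration):

```python
# Editor's sketch, assuming the contrib-era API.
import tensorflow as tf

tf.enable_eager_execution()

defined = tf.contrib.eager.defun(lambda: tf.constant(1.0))

defined()                            # traced under an empty device stack
with tf.device('/cpu:0'):
  defined()                          # retraced: different device stack
print(len(defined._function_cache))  # 2
```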
def _canonicalize_function_inputs(self, *args, **kwds):
"""Canonicalizes `args` and `kwds`.
Canonicalize the inputs to the Python function using its fullargspec. In
particular, we parse the varags and kwargs that this
- `_PolymorphicFunction` was called with into a tuple corresponding to the
+ `PolymorphicFunction` was called with into a tuple corresponding to the
Python function's positional (named) arguments and a dictionary
corresponding to its kwargs.
@@ -1029,36 +1119,30 @@ class _PolymorphicFunction(object):
"""
args, kwds = self._canonicalize_function_inputs(*args, **kwds)
- cache_key = self._cache_key(args, kwds)
+ cache_key = self._cache_key(args, kwds, context.context(),
+ ops.get_default_graph())
with self._lock:
try:
- graph_function = self._arguments_to_functions.get(cache_key, None)
+ graph_function = self._function_cache.get(cache_key, None)
except TypeError:
raise TypeError("Arguments supplied to `defun`-generated functions "
"must be hashable.")
if graph_function is None:
- graph_function = GraphCallable(
- _func_graph_from_py_func(self._name, self._python_function, args,
- kwds, self._input_signature))
+ graph_function = Function(
+ func_graph_from_py_func(self._name, self._python_function, args,
+ kwds, self._input_signature))
self._variables.extend(
[v for v in graph_function.variables if v not in self._variables])
- self._arguments_to_functions[cache_key] = graph_function
+ self._function_cache[cache_key] = graph_function
return graph_function, (args, kwds)
- def __call__(self, *args, **kwds):
- """Calls a graph function specialized for this input signature."""
- graph_function, inputs = self._maybe_define_function(*args, **kwds)
- return graph_function(*inputs)
-
- def call_python_function(self, *args, **kwargs):
- """Directly calls the wrapped python function."""
- return self._python_function(*args, **kwargs)
- @property
- def variables(self):
- """Returns a list of variables used in any of the defined functions."""
- return self._variables
+def _validate_signature(signature):
+ if any(not isinstance(arg, tensor_spec.TensorSpec)
+ for arg in nest.flatten(signature)):
+ raise TypeError("Invalid input_signature %s; input_signature must be "
+ "a possibly nested sequence of TensorSpec objects.")
def defun(func=None, input_signature=None):
@@ -1136,6 +1220,7 @@ def defun(func=None, input_signature=None):
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
self.keep_probability = keep_probability
+ @tf.contrib.eager.defun
def call(self, inputs, training=True):
x = self.dense2(self.dense1(inputs))
if training:
@@ -1144,7 +1229,6 @@ def defun(func=None, input_signature=None):
return x
model = MyModel()
- model.call = tf.contrib.eager.defun(model.call)
model(x, training=True) # executes a graph, with dropout
model(x, training=False) # executes a graph, without dropout
@@ -1371,7 +1455,15 @@ def defun(func=None, input_signature=None):
function (and return zero or more `tf.Tensor` objects).
If `func` is None, returns a decorator that, when invoked with a single
`func` argument, returns a callable equivalent to the case above.
+
+ Raises:
+ TypeError: If `input_signature` is neither `None` nor a sequence of
+ `tf.contrib.eager.TensorSpec` objects.
"""
+
+ if input_signature is not None:
+ _validate_signature(input_signature)
+
# TODO(apassos): deal with captured global state. Deal with control flow.
def decorated(function):
try:
@@ -1380,8 +1472,7 @@ def defun(func=None, input_signature=None):
name = "function"
return tf_decorator.make_decorator(
function,
- _PolymorphicFunction(
- function, name, input_signature=input_signature))
+ PolymorphicFunction(function, name, input_signature=input_signature))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
@@ -1397,52 +1488,6 @@ def defun(func=None, input_signature=None):
return decorated
-def make_defun_op(func, *args, **kwds):
- """Compile func into graph_mode, assuming func arguments are *args, **kwargs.
-
- `make_defun_op` converts a function that constructs a TensorFlow graph into
- a function object and attaches it to the graph. The resulting function
- object can be queried for its properties, and called directly with different
- inputs to execute.
-
- More details on use cases and limitations are available in the
- documentation for `defun`.
-
- Example:
- ```python
- def f(x, y):
- return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)
-
- def g(x, y):
- return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)
-
- z = tf.constant([[0.0, 0.0]])
- g_op = make_defun_op(g, z, z)
-
- assert g_op.output_shapes == tf.TensorShape([])
- assert g_op.output_types == tf.float32
-
- x = tf.constant([[2.0, 3.0]])
- y = tf.constant([[3.0, -2.0]])
-
- # The plain function and defun-compiled function should return the same value.
- assert f(x, y).numpy() == g_op(x, y).numpy()
- ```
-
- Args:
- func: function to be compiled.
- *args: List arguments to pass to `func` when attaching to the graph.
- **kwds: Keyword arguments to pass to `func` when attaching to the graph.
-
- Returns:
- A wrapper object which can be queried for its output properties,
- and which can be called directly the way a `@defun` wrapped function
- can.
- """
- return GraphCallable(
- _func_graph_from_py_func(func.__name__, func, args, kwds))
-
-
class AutomaticControlDependencies(object):
"""Context manager to automatically add control dependencies.
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index ca6aafd715..3c79099d87 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -130,16 +130,16 @@ class FunctionTest(test.TestCase):
with ops.Graph().as_default():
self.assertEqual(f().shape, ())
- def testBasicDefunOpGraphMode(self):
+ def testBasicGraphFunction(self):
matmul = function.defun(math_ops.matmul)
+ @function.defun
def sq(a):
return matmul(a, a)
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
- sq_op = function.make_defun_op(sq, t)
-
+ sq_op = sq.get_concrete_function(t)
self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
out = sq_op(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
@@ -223,33 +223,32 @@ class FunctionTest(test.TestCase):
g, = gradients_impl.gradients(f_c, c)
self.assertAllEqual(sess.run(g), [[1.0]])
- def testNestedInputsDefunOpGraphMode(self):
+ def testNestedInputsGraphFunction(self):
matmul = function.defun(math_ops.matmul)
pair = collections.namedtuple('pair', ['a', 'b'])
+ @function.defun
def a_times_b(inputs):
return matmul(inputs.a['a'], inputs.b['b'])
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
-
inputs = pair({'a': t}, {'b': t})
- sq_op = function.make_defun_op(a_times_b, inputs)
-
+ sq_op = a_times_b.get_concrete_function(inputs)
self.assertEqual(sq_op.output_shapes, tensor_shape.TensorShape([2, 2]))
out = sq_op(inputs)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
- def testNestedOutputDefunOpGraphMode(self):
+ def testNestedOutputGraphFunction(self):
matmul = function.defun(math_ops.matmul)
+ @function.defun
def sq(a):
return (matmul(a, a), {'b': constant_op.constant(1.0)})
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
- sq_op = function.make_defun_op(sq, t)
-
+ sq_op = sq.get_concrete_function(t)
self.assertEqual(sq_op.output_shapes,
(tensor_shape.TensorShape([2, 2]),
{'b': tensor_shape.TensorShape([])}))
@@ -259,28 +258,28 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(a, math_ops.matmul(t, t).numpy())
self.assertAllEqual(b['b'].numpy(), 1.0)
- def testDefunOpGraphModeWithGradients(self):
+ def testGraphFunctionWithGradients(self):
v = resource_variable_ops.ResourceVariable(1.0, name='v')
+ @function.defun
def step():
def inner():
return v * v
return backprop.implicit_grad(inner)()[0][0]
- step_op = function.make_defun_op(step)
-
+ step_op = step.get_concrete_function()
self.assertEqual(step_op.output_dtypes, dtypes.float32)
self.assertEqual(step_op.output_shapes, tensor_shape.TensorShape([]))
self.assertAllEqual(step_op(), 2.0)
- def testDefunOpGraphModeNoneOutput(self):
+ def testGraphFunctionNoneOutput(self):
+ @function.defun
def fn(unused_a, unused_b):
return None
x = constant_op.constant(1)
- fn_op = function.make_defun_op(fn, x, x)
-
+ fn_op = fn.get_concrete_function(x, x)
self.assertEqual(fn_op.output_dtypes, None)
self.assertEqual(fn_op.output_shapes, None)
self.assertAllEqual(fn_op(x, x), None)
@@ -321,13 +320,13 @@ class FunctionTest(test.TestCase):
x = random_ops.random_uniform([2, 2]).numpy()
defined = function.defun(f)
defined(x)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
x = random_ops.random_uniform([2, 2]).numpy()
defined(x)
# A NumPy array with different values but the same shape and dtype
# shouldn't trigger another function definition.
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
def testDefunCapturedInt32(self):
x = constant_op.constant(1, dtype=dtypes.int32)
@@ -358,6 +357,47 @@ class FunctionTest(test.TestCase):
self.assertEqual(3.0, float(test_assign_add()))
+ @test_util.run_in_graph_and_eager_modes
+ def testTensorInitializationInFunctionRaisesError(self):
+ error_msg = ('Tensor-typed variable initializers must either be '
+ 'wrapped in an init_scope or callable.*')
+
+ @function.defun
+ def tensor_init():
+ with self.assertRaisesRegexp(ValueError, error_msg):
+ resource_variable_ops.ResourceVariable(constant_op.constant(2.0))
+
+ tensor_init()
+
+ @test_util.run_in_graph_and_eager_modes
+ def testCallableTensorInitializationInFunction(self):
+
+ @function.defun
+ def tensor_init():
+ v = resource_variable_ops.ResourceVariable(
+ lambda: constant_op.constant(2.0))
+ return v.read_value()
+
+ value = tensor_init()
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEqual(self.evaluate(value), 2.0)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testInitScopeTensorInitializationInFunction(self):
+
+ @function.defun
+ def tensor_init():
+ with ops.init_scope():
+ const = constant_op.constant(2.0)
+ v = resource_variable_ops.ResourceVariable(const)
+ return v.read_value()
+
+ value = tensor_init()
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.assertEqual(self.evaluate(value), 2.0)
+
def testDefunShapeInferenceWithCapturedResourceVariable(self):
v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
@@ -645,17 +685,19 @@ class FunctionTest(test.TestCase):
def testReturningIndexedSlicesWithDefun(self):
def validate(indexed_slice):
+ @function.defun
def f():
return indexed_slice
- output = function.defun(f)()
+ output = f()
self.assertTrue(isinstance(output, ops.IndexedSlices))
self.assertAllEqual(indexed_slice.values, output.values)
self.assertAllEqual(indexed_slice.indices, output.indices)
self.assertAllEqual(indexed_slice.dense_shape, output.dense_shape)
self.assertEqual(
- function.make_defun_op(f).output_shapes, indexed_slice.values.shape)
+ f.get_concrete_function().output_shapes,
+ indexed_slice.values.shape)
arg = ops.IndexedSlices(
values=constant_op.constant([1, 2]),
@@ -978,39 +1020,109 @@ class FunctionTest(test.TestCase):
config=config_pb2.ConfigProto(device_count={'CPU': 4}))
def testDeviceAnnotationsRespected(self):
- @function.defun
def multi_device_fn():
with ops.device('/cpu:0'):
- s1 = iterator_ops.Iterator.from_structure(
+ s0 = iterator_ops.Iterator.from_structure(
(dtypes.float32,)).string_handle()
with ops.device('/cpu:1'):
- s2 = iterator_ops.Iterator.from_structure(
+ s1 = iterator_ops.Iterator.from_structure(
(dtypes.float32,)).string_handle()
with ops.device('/cpu:2'):
- s3 = iterator_ops.Iterator.from_structure(
- (dtypes.float32,)).string_handle()
- with ops.device(''):
- # TODO(akshayka): This is unfortunate and brittle. It prevents
- # `Iterator.from_structure` from assigning the iterator op to 'cpu:0'.
- # Remove this hack once we have a way of obtaining metadata about
- # function execution.
- s4 = iterator_ops.Iterator.from_structure(
+ s2 = iterator_ops.Iterator.from_structure(
(dtypes.float32,)).string_handle()
- return s1, s2, s3, s4
+ s3 = iterator_ops.Iterator.from_structure(
+ (dtypes.float32,)).string_handle()
+ return s0, s1, s2, s3
- with ops.device('/cpu:3'):
- outputs = self.evaluate(multi_device_fn())
+ defined = function.defun(multi_device_fn)
+ outputs = self.evaluate(defined())
+ self.assertEqual(len(defined._function_cache), 1)
self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
- self.assertIn(compat.as_bytes('CPU:3'), outputs[3])
- with ops.device('/cpu:0'):
- outputs = self.evaluate(multi_device_fn())
+ with ops.device('/cpu:3'):
+ outputs = self.evaluate(defined())
+ self.assertEqual(len(defined._function_cache), 2)
self.assertIn(compat.as_bytes('CPU:0'), outputs[0])
self.assertIn(compat.as_bytes('CPU:1'), outputs[1])
self.assertIn(compat.as_bytes('CPU:2'), outputs[2])
- self.assertIn(compat.as_bytes('CPU:0'), outputs[3])
+ self.assertIn(compat.as_bytes('CPU:3'), outputs[3])
+
+ # This should retrieve the call-site-device-agnostic function.
+ defined()
+ self.assertEqual(len(defined._function_cache), 2)
+
+ # And this should retrieve the function created for '/cpu:3'.
+ with ops.device('/cpu:3'):
+ defined()
+ self.assertEqual(len(defined._function_cache), 2)
+
+ @test_util.run_in_graph_and_eager_modes(
+ config=config_pb2.ConfigProto(device_count={'CPU': 2}))
+ def testCallingGraphFunctionOnIncompatibleDeviceRaisesError(self):
+
+ def func():
+ return constant_op.constant(0)
+
+ defined = function.defun(func)
+ with ops.device('cpu:0'):
+ cpu_graph_function = defined.get_concrete_function()
+
+ with ops.device('cpu:0'):
+ self.assertEqual(
+ self.evaluate(cpu_graph_function()), self.evaluate(func()))
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'The current device stack does not match the device stack under '
+ 'which the TensorFlow function \'.*func.*\' was created.\n'
+ 'Current device stack: .*\n.*func.* device stack.*'):
+ with ops.device('cpu:1'):
+ cpu_graph_function()
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'The current device stack does not match the device stack under '
+ 'which the TensorFlow function \'.*func.*\' was created.\n'
+ 'Current device stack: .*\n.*func.* device stack.*'):
+ with ops.device(None):
+ cpu_graph_function()
+
+ default_graph_function = defined.get_concrete_function()
+ self.assertEqual(
+ self.evaluate(default_graph_function()), self.evaluate(func()))
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'The current device stack does not match the device stack under '
+ 'which the TensorFlow function \'.*func.*\' was created.\n'
+ 'Current device stack: .*\n.*func.* device stack.*'):
+ with ops.device('cpu:1'):
+ default_graph_function()
+
+ @test_util.run_in_graph_and_eager_modes
+ def testColocateWithRespected(self):
+ # TODO(b/113291792): Use multiple CPUs instead of a GPU.
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found.')
+
+ with ops.device('cpu:0'):
+ x = constant_op.constant(1.0)
+
+ with ops.device('gpu:0'):
+ y = constant_op.constant(1.0)
+
+ @function.defun
+ def foo():
+ return iterator_ops.Iterator.from_structure(
+ (dtypes.float32,)).string_handle()
+
+ with ops.colocate_with(x):
+ self.assertIn(compat.as_bytes('CPU:0'), self.evaluate(foo()))
+
+ with ops.colocate_with(y):
+ self.assertIn(compat.as_bytes('GPU:0'), self.evaluate(foo()))
def testVariablesAreTracked(self):
v = resource_variable_ops.ResourceVariable(1.0)
@@ -1039,26 +1151,31 @@ class FunctionTest(test.TestCase):
defined = function.defun(func)
defined(0, baz=20)
+
+ def cache_keys():
+ """Sanitizes cache keys of non-input metadata."""
+ return tuple(key[:3] for key in defined._function_cache)
+
# `True` corresponds to the fact that we're executing eagerly
- self.assertIn((0, 1, 20, True), defined._arguments_to_functions)
+ self.assertIn((0, 1, 20), cache_keys())
defined(1) # bar=1, baz=2
- self.assertIn((1, 1, 2, True), defined._arguments_to_functions)
+ self.assertIn((1, 1, 2), cache_keys())
# This matches the previous call.
defined(foo=1)
- self.assertEqual(len(defined._arguments_to_functions), 2)
+ self.assertEqual(len(defined._function_cache), 2)
defined(1, 2, 3)
- self.assertIn((1, 2, 3, True), defined._arguments_to_functions)
+ self.assertIn((1, 2, 3), cache_keys())
# This matches the previous call.
defined(1, bar=2, baz=3)
- self.assertEqual(len(defined._arguments_to_functions), 3)
+ self.assertEqual(len(defined._function_cache), 3)
# This matches the previous call.
defined(1, baz=3, bar=2)
- self.assertEqual(len(defined._arguments_to_functions), 3)
+ self.assertEqual(len(defined._function_cache), 3)
def testFunctoolsPartialUnwrappedCorrectly(self):
@@ -1084,7 +1201,7 @@ class FunctionTest(test.TestCase):
defined = function.defun(foo, input_signature=signature)
a = array_ops.ones([2])
out = defined(a)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
self.assertAllEqual(out, a)
def bar(a):
@@ -1095,13 +1212,13 @@ class FunctionTest(test.TestCase):
defined = function.defun(bar, input_signature=signature)
a = array_ops.ones([2, 1])
out = defined(a)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
self.assertAllEqual(out, a)
# Changing the second dimension shouldn't create a new function.
b = array_ops.ones([2, 3])
out = defined(b)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
self.assertAllEqual(out, b)
def testNestedInputSignatures(self):
@@ -1118,7 +1235,7 @@ class FunctionTest(test.TestCase):
a = array_ops.ones([2, 1])
b = array_ops.ones([1])
out = defined([a, a], b)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
nest.assert_same_structure(out, [[a, a], b])
self.assertAllEqual(out[0][0], a)
self.assertAllEqual(out[0][1], a)
@@ -1129,7 +1246,7 @@ class FunctionTest(test.TestCase):
b = array_ops.ones([2, 5])
c = array_ops.ones([1])
out = defined([a, b], c)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
nest.assert_same_structure(out, [[a, b], c])
self.assertAllEqual(out[0][0], a)
self.assertAllEqual(out[0][1], b)
@@ -1165,13 +1282,13 @@ class FunctionTest(test.TestCase):
# Signatures must consist exclusively of `TensorSpec` objects.
signature = [(2, 3), tensor_spec.TensorSpec([2, 3], dtypes.float32)]
with self.assertRaisesRegexp(TypeError, 'Invalid input_signature.*'):
- function.defun(foo, input_signature=signature)(1, 2)
+ function.defun(foo, input_signature=signature)
# Signatures must be either lists or tuples on their outermost levels.
signature = {'t1': tensor_spec.TensorSpec([], dtypes.float32)}
with self.assertRaisesRegexp(TypeError, 'input_signature must be either a '
'tuple or a list.*'):
- function.defun(foo, input_signature=signature)(1, 2)
+ function.defun(foo, input_signature=signature)
def testInputsIncompatibleWithSignatureRaisesError(self):
@@ -1225,22 +1342,22 @@ class FunctionTest(test.TestCase):
integer = constant_op.constant(2, dtypes.int64)
out1, out2 = foo(flt, integer)
- self.assertEqual(len(foo._arguments_to_functions), 1)
+ self.assertEqual(len(foo._function_cache), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(flt=flt, integer=integer)
- self.assertEqual(len(foo._arguments_to_functions), 1)
+ self.assertEqual(len(foo._function_cache), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(integer=integer, flt=flt)
- self.assertEqual(len(foo._arguments_to_functions), 1)
+ self.assertEqual(len(foo._function_cache), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
out1, out2 = foo(flt, integer=integer)
- self.assertEqual(len(foo._arguments_to_functions), 1)
+ self.assertEqual(len(foo._function_cache), 1)
self.assertEqual(out1.numpy(), 1.0)
self.assertEqual(out2.numpy(), 2)
@@ -1270,27 +1387,27 @@ class FunctionTest(test.TestCase):
a = constant_op.constant(2.0)
b = constant_op.constant([1.0, 2.0])
one = defined(a, b)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
two = defined(a=a, b=b)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
three = defined(b=b, a=a)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
four = defined(a, b=b)
- self.assertEqual(len(defined._arguments_to_functions), 1)
+ self.assertEqual(len(defined._function_cache), 1)
# The next call corresponds to a new input signature, hence
# we expect another function to be defined.
five = defined(b, a)
- self.assertEqual(len(defined._arguments_to_functions), 2)
+ self.assertEqual(len(defined._function_cache), 2)
six = defined(a=b, b=a)
- self.assertEqual(len(defined._arguments_to_functions), 2)
+ self.assertEqual(len(defined._function_cache), 2)
seven = defined(b=a, a=b)
- self.assertEqual(len(defined._arguments_to_functions), 2)
+ self.assertEqual(len(defined._function_cache), 2)
self.assertAllEqual(one, [1.0, 2.0])
self.assertAllEqual(two, [1.0, 2.0])
@@ -1375,7 +1492,7 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(state, [0])
# Whereas calling the python function directly should create a side-effect.
- side_effecting_function.call_python_function()
+ side_effecting_function.python_function()
self.assertAllEqual(state, [0, 0])
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index a916a75f00..823c4078b8 100644..100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -89,7 +89,7 @@ int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
PyObject* exception);
// Returns the string associated with the passed-in python object.
-char* TFE_GetPythonString(PyObject* o);
+const char* TFE_GetPythonString(PyObject* o);
// Returns a unique id on each call.
int64_t get_uid();
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 2d54555cd3..64cf36d079 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -216,7 +216,7 @@ bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
#if PY_MAJOR_VERSION >= 3
if (PyUnicode_Check(py_value)) {
Py_ssize_t size = 0;
- char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
+ const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
if (buf == nullptr) return false;
*value = tensorflow::StringPiece(buf, size);
return true;
@@ -825,7 +825,7 @@ int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
return -1;
}
-char* TFE_GetPythonString(PyObject* o) {
+const char* TFE_GetPythonString(PyObject* o) {
if (PyBytes_Check(o)) {
return PyBytes_AsString(o);
}
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 817c8e6848..9fce172bee 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -211,6 +211,9 @@ py_test(
shard_count = 2,
srcs_version = "PY2AND3",
tags = [
+ "manual",
+ "no_oss",
+ "notap",
"optonly",
],
deps = [
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index da9a64c2bc..06593f9520 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -335,8 +335,8 @@ def _check_dense_labels_match_logits_and_reshape(
'Expected labels dimension=%s. Received %s. '
'Suggested Fix:'
'If your classifier expects one-hot encoding label,'
- 'check your n_classes argument to the estimator'
- 'and/or the shape of your label.'
+ 'check your n_classes argument to the estimator '
+ 'and/or the shape of your label. '
'Otherwise, check the shape of your label.' %
(expected_labels_dimension, dim1))
expected_labels_shape = array_ops.concat(
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index f7ee42c7f6..97a02bd1e8 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -120,7 +120,9 @@ class Estimator(object):
warm_start_from=None):
"""Constructs an `Estimator` instance.
- See [estimators](https://tensorflow.org/guide/estimators) for more information.
+ See [estimators](https://tensorflow.org/guide/estimators) for more
+ information.
+
To warm-start an `Estimator`:
```python
@@ -286,8 +288,8 @@ class Estimator(object):
Args:
input_fn: A function that provides input data for training as minibatches.
- See [Premade
- Estimators](https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ See [Premade Estimators](
+ https://tensorflow.org/guide/premade_estimators#create_input_functions)
for more information. The function should construct and return one of
the following: * A
`tf.data.Dataset` object: Outputs of `Dataset` object must be a tuple
@@ -405,7 +407,8 @@ class Estimator(object):
Args:
input_fn: A function that constructs the input data for evaluation. See
- [Premade Estimators](https://tensorflow.org/guide/premade#create_input_functions}
+ [Premade Estimators](
+ https://tensorflow.org/guide/premade#create_input_functions)
for more information. The
function should construct and return one of the following: * A
`tf.data.Dataset` object: Outputs of `Dataset` object must be a tuple
@@ -431,7 +434,11 @@ class Estimator(object):
Returns:
A dict containing the evaluation metrics specified in `model_fn` keyed by
name, as well as an entry `global_step` which contains the value of the
- global step for which this evaluation was performed.
+ global step for which this evaluation was performed. For canned
+ estimators, the dict contains the `loss` (mean loss per mini-batch) and
+ the `average_loss` (mean loss per sample). Canned classifiers also return
+ the `accuracy`. Canned regressors also return the `label/mean` and the
+ `prediction/mean`.
Raises:
ValueError: If `steps <= 0`.
@@ -488,8 +495,8 @@ class Estimator(object):
input_fn: A function that constructs the features. Prediction continues
until `input_fn` raises an end-of-input exception
(`tf.errors.OutOfRangeError` or `StopIteration`).
- See [Premade
- Estimators](https://tensorflow.org/guide/premade_estimators#create_input_functions)
+ See [Premade Estimators](
+ https://tensorflow.org/guide/premade_estimators#create_input_functions)
for more information. The function should construct and return one of
the following:
@@ -602,6 +609,38 @@ class Estimator(object):
as_text=False,
checkpoint_path=None,
strip_default_attrs=False):
+ # pylint: disable=line-too-long,g-doc-args,g-doc-return-or-yield
+ """Exports inference graph as a `SavedModel` into the given dir.
+
+ Note that `export_savedmodel` will be renamed to `export_saved_model`
+ in TensorFlow 2.0. At that time, `export_savedmodel` (without the
+ additional underscore) will be available only through tf.compat.v1.
+
+ Please see `tf.estimator.Estimator.export_saved_model` for more information.
+
+ There is one additional arg versus the new method:
+ strip_default_attrs: This parameter is going away in TF 2.0, and
+ the new behavior will automatically strip all default attributes.
+ Boolean. If `True`, default-valued attributes will be
+ removed from the `NodeDef`s. For a detailed guide, see [Stripping
+ Default-Valued Attributes](
+ https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+ """
+ # pylint: enable=line-too-long,g-doc-args,g-doc-return-or-yield
+ return self._export_saved_model_for_mode(
+ export_dir_base,
+ serving_input_receiver_fn,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ checkpoint_path=checkpoint_path,
+ strip_default_attrs=strip_default_attrs,
+ mode=model_fn_lib.ModeKeys.PREDICT)
+
+ def export_saved_model(
+ self, export_dir_base, serving_input_receiver_fn,
+ assets_extra=None,
+ as_text=False,
+ checkpoint_path=None):
# pylint: disable=line-too-long
"""Exports inference graph as a `SavedModel` into the given dir.
@@ -648,28 +687,25 @@ class Estimator(object):
as_text: whether to write the `SavedModel` proto in text format.
checkpoint_path: The checkpoint path to export. If `None` (the default),
the most recent checkpoint found within the model directory is chosen.
- strip_default_attrs: Boolean. If `True`, default-valued attributes will be
- removed from the `NodeDef`s. For a detailed guide, see [Stripping
- Default-Valued
- Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
The string path to the exported directory.
Raises:
ValueError: if no `serving_input_receiver_fn` is provided, no
- `export_outputs`
- are provided, or no checkpoint can be found.
+ `export_outputs` are provided, or no checkpoint can be found.
"""
# pylint: enable=line-too-long
- return self._export_saved_model_for_mode(
+ # TODO(b/111442174): `export_savedmodel` will be renamed to
+ # `export_saved_model` in TensorFlow 2.0. This function is a wrapper
+ # while staging the new version; do not add any logic here.
+ return self.export_savedmodel(
export_dir_base,
serving_input_receiver_fn,
assets_extra=assets_extra,
as_text=as_text,
checkpoint_path=checkpoint_path,
- strip_default_attrs=strip_default_attrs,
- mode=model_fn_lib.ModeKeys.PREDICT)
+ strip_default_attrs=True)
def _export_saved_model_for_mode(
self, export_dir_base, input_receiver_fn,
@@ -1237,7 +1273,8 @@ class Estimator(object):
# We want to create the iterations variable outside the distribution scope
# as that is just stored on the host and mainly used to drive the loop
# and doesn't need to be a Mirrored/Device variable.
- steps_per_run_variable = training.get_or_create_steps_per_run_variable()
+ if is_tpu_strategy:
+ steps_per_run_variable = training.get_or_create_steps_per_run_variable()
with self._train_distribution.scope():
random_seed.set_random_seed(self._config.tf_random_seed)
iterator, input_hooks = self._get_iterator_from_input_fn(
@@ -1252,7 +1289,7 @@ class Estimator(object):
if is_tpu_strategy:
# Create a step_fn from the train_op of grouped_estimator_spec
- def step_fn(ctx, features, labels):
+ def step_fn(ctx, features, labels=None):
"""A single step that is passed to run_on_dataset."""
estimator_spec = self._train_distribution.call_for_each_tower(
self._call_model_fn,
@@ -1277,7 +1314,8 @@ class Estimator(object):
loss = ctx.last_step_outputs['loss']
grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
else:
- features, labels = iterator.get_next()
+ features, labels = estimator_util.parse_iterator_result(
+ iterator.get_next())
grouped_estimator_spec = self._train_distribution.call_for_each_tower(
self._call_model_fn,
features,
@@ -1466,7 +1504,7 @@ class Estimator(object):
self._eval_distribution.__class__.__name__ == 'TPUStrategy')
if is_tpu_strategy:
- def step_fn(ctx, features, labels):
+ def step_fn(ctx, features, labels=None):
"""Runs one step of the eval computation and captures outputs."""
estimator_spec = self._eval_distribution.call_for_each_tower(
self._call_model_fn, features, labels, model_fn_lib.ModeKeys.EVAL,
@@ -1487,7 +1525,8 @@ class Estimator(object):
eval_dict = ctx.non_tensor_outputs['eval_dict']
grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
else:
- features, labels = iterator.get_next()
+ features, labels = estimator_util.parse_iterator_result(
+ iterator.get_next())
grouped_estimator_spec = self._eval_distribution.call_for_each_tower(
self._call_model_fn, features, labels,
model_fn_lib.ModeKeys.EVAL, config)
diff --git a/tensorflow/python/estimator/export/export.py b/tensorflow/python/estimator/export/export.py
index 3d171f7811..55aace5fa9 100644
--- a/tensorflow/python/estimator/export/export.py
+++ b/tensorflow/python/estimator/export/export.py
@@ -217,6 +217,29 @@ class TensorServingInputReceiver(
receiver_tensors_alternatives=receiver.receiver_tensors_alternatives)
+class UnsupervisedInputReceiver(ServingInputReceiver):
+ """A return type for a training_input_receiver_fn or eval_input_receiver_fn.
+
+ This differs from SupervisedInputReceiver in that it does not require a set
+ of labels.
+
+ The expected return values are:
+ features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or
+ `SparseTensor`, specifying the features to be passed to the model.
+ receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor`
+ or `SparseTensor`, specifying input nodes where this receiver expects to
+ be fed by default. Typically, this is a single placeholder expecting
+ serialized `tf.Example` protos.
+ """
+
+ def __new__(cls, features, receiver_tensors):
+ return super(UnsupervisedInputReceiver, cls).__new__(
+ cls,
+ features=features,
+ receiver_tensors=receiver_tensors,
+ receiver_tensors_alternatives=None)
+
+
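A minimal sketch of building such a receiver in a label-free input-receiver fn; the internal import path is used because the diff shows no public export, and the feature spec is purely illustrative:

```python
# Editor's sketch; internal import path and feature spec are illustrative.
import tensorflow as tf
from tensorflow.python.estimator.export.export import UnsupervisedInputReceiver


def input_receiver_fn():
  serialized = tf.placeholder(tf.string, shape=[None], name='examples')
  features = tf.parse_example(
      serialized, {'x': tf.FixedLenFeature([1], tf.float32)})
  # No labels argument, unlike SupervisedInputReceiver.
  return UnsupervisedInputReceiver(
      features=features, receiver_tensors={'examples': serialized})
```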
class SupervisedInputReceiver(
collections.namedtuple('SupervisedInputReceiver',
['features', 'labels', 'receiver_tensors'])):
@@ -288,13 +311,33 @@ def build_parsing_serving_input_receiver_fn(feature_spec,
def _placeholder_from_tensor(t, default_batch_size=None):
+ """Creates a placeholder that matches the dtype and shape of passed tensor.
+
+ Args:
+ t: Tensor or EagerTensor
+ default_batch_size: the number of query examples expected per batch.
+ Leave unset for variable batch size (recommended).
+
+ Returns:
+ Placeholder that matches the passed tensor.
+ """
batch_shape = tensor_shape.TensorShape([default_batch_size])
shape = batch_shape.concatenate(t.get_shape()[1:])
# Reuse the feature tensor's op name (t.op.name) for the placeholder,
# excluding the index from the tensor's name (t.name):
# t.name = "%s:%d" % (t.op.name, t._value_index)
- return array_ops.placeholder(dtype=t.dtype, shape=shape, name=t.op.name)
+ try:
+ name = t.op.name
+ except AttributeError:
+ # In Eager mode, tensors don't have ops or names, and while they do have
+ # IDs, those are not maintained across runs. The name here is used
+ # primarily for debugging, and is not critical to the placeholder.
+ # So, in order to make this Eager-compatible, continue with an empty
+ # name if none is available.
+ name = None
+
+ return array_ops.placeholder(dtype=t.dtype, shape=shape, name=name)
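The shape handling above keeps every per-example dimension and swaps in a configurable batch dimension. A minimal standalone sketch of that logic (graph mode assumed):

```python
# Editor's sketch of the shape logic above; graph mode assumed.
import tensorflow as tf

t = tf.constant([[1.0, 2.0]])                       # shape (1, 2)
batch_shape = tf.TensorShape([None])                # default_batch_size=None
shape = batch_shape.concatenate(t.get_shape()[1:])  # (None, 2)
p = tf.placeholder(dtype=t.dtype, shape=shape)
print(p.shape)  # (?, 2)
```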
def _placeholders_from_receiver_tensors_dict(input_vals,
diff --git a/tensorflow/python/estimator/export/export_test.py b/tensorflow/python/estimator/export/export_test.py
index 1d475adb43..3eed1ab163 100644
--- a/tensorflow/python/estimator/export/export_test.py
+++ b/tensorflow/python/estimator/export/export_test.py
@@ -163,6 +163,29 @@ class ServingInputReceiverTest(test_util.TensorFlowTestCase):
_ = export.ServingInputReceiver(feature, receiver_tensor)
+class UnsupervisedInputReceiverTest(test_util.TensorFlowTestCase):
+
+ # Since this is basically a wrapper around ServingInputReceiver, we only
+ # have a simple sanity check to ensure that it works.
+
+ def test_unsupervised_input_receiver_constructor(self):
+ """Tests that no errors are raised when input is expected."""
+ features = {
+ "feature0":
+ constant_op.constant([0]),
+ u"feature1":
+ constant_op.constant([1]),
+ "feature2":
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
+ }
+ receiver_tensors = {
+ "example0": array_ops.placeholder(dtypes.string, name="example0"),
+ u"example1": array_ops.placeholder(dtypes.string, name="example1"),
+ }
+ export.UnsupervisedInputReceiver(features, receiver_tensors)
+
+
class SupervisedInputReceiverTest(test_util.TensorFlowTestCase):
def test_input_receiver_constructor(self):
@@ -393,6 +416,7 @@ class ExportTest(test_util.TensorFlowTestCase):
tensor_shape.unknown_shape(),
v.receiver_tensors["feature_2"].shape)
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_serving_input_receiver_fn(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -411,6 +435,7 @@ class ExportTest(test_util.TensorFlowTestCase):
dtypes.int32,
serving_input_receiver.receiver_tensors["feature_2"].dtype)
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_supervised_input_receiver_fn(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -431,6 +456,7 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertEqual(
dtypes.int32, input_receiver.receiver_tensors["feature_2"].dtype)
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_supervised_input_receiver_fn_raw_tensors(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -454,6 +480,7 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertEqual(set(["input", "label"]),
set(input_receiver.receiver_tensors.keys()))
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_supervised_input_receiver_fn_batch_size(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -466,6 +493,7 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertEqual([10], input_receiver.receiver_tensors["feature_1"].shape)
self.assertEqual([10], input_receiver.features["feature_1"].shape)
+ @test_util.run_in_graph_and_eager_modes
def test_build_raw_supervised_input_receiver_fn_overlapping_keys(self):
features = {"feature_1": constant_op.constant(["hello"]),
"feature_2": constant_op.constant([42])}
@@ -474,6 +502,7 @@ class ExportTest(test_util.TensorFlowTestCase):
with self.assertRaises(ValueError):
export.build_raw_supervised_input_receiver_fn(features, labels)
+ @test_util.run_in_graph_and_eager_modes
def test_build_supervised_input_receiver_fn_from_input_fn(self):
def dummy_input_fn():
return ({"x": constant_op.constant([[1], [1]]),
@@ -491,6 +520,7 @@ class ExportTest(test_util.TensorFlowTestCase):
self.assertEqual(set(["x", "y", "label"]),
set(input_receiver.receiver_tensors.keys()))
+ @test_util.run_in_graph_and_eager_modes
def test_build_supervised_input_receiver_fn_from_input_fn_args(self):
def dummy_input_fn(feature_key="x"):
return ({feature_key: constant_op.constant([[1], [1]]),
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index ce6ad47c01..6361c6acc1 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -36,7 +36,6 @@ from tensorflow.python.keras import optimizers
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_module
-from tensorflow.python.ops import variables as variables_module
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
@@ -315,15 +314,7 @@ def _save_first_checkpoint(keras_model, custom_objects, config):
if not model.train_function:
# pylint: disable=protected-access
model._make_train_function()
- # We are using global variables collection here because:
- # estimator runs eager mode under context.graph_mode() context manager
- # When we try to get all the TF optimizer variables using
- # optimizer.variables() we try to return variables that belong to the
- # current graph. This check (variable.op.graph is current_graph) will
- # error as the context is graph mode but variables are eager.
- # TODO(psv): investigate this and see if we can remove the usage of
- # collection here.
- K._initialize_variables(sess, variables_module.global_variables())
+ K._initialize_variables(sess)
# pylint: enable=protected-access
saver = saver_lib.Saver()
latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt')
diff --git a/tensorflow/python/estimator/run_config.py b/tensorflow/python/estimator/run_config.py
index 220c3e58ca..b1ca207b62 100644
--- a/tensorflow/python/estimator/run_config.py
+++ b/tensorflow/python/estimator/run_config.py
@@ -26,6 +26,7 @@ import six
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.distribute import estimator_training as distribute_coordinator_training
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat_internal
@@ -51,6 +52,7 @@ _DEFAULT_REPLACEABLE_LIST = [
'device_fn',
'protocol',
'eval_distribute',
+ 'experimental_distribute',
]
_SAVE_CKPT_ERR = (
@@ -331,7 +333,8 @@ class RunConfig(object):
train_distribute=None,
device_fn=None,
protocol=None,
- eval_distribute=None):
+ eval_distribute=None,
+ experimental_distribute=None):
"""Constructs a RunConfig.
All distributed training related properties `cluster_spec`, `is_chief`,
@@ -458,7 +461,8 @@ class RunConfig(object):
train_distribute: An optional instance of
`tf.contrib.distribute.DistributionStrategy`. If specified,
then Estimator will distribute the user's model during training,
- according to the policy specified by that strategy.
+ according to the policy specified by that strategy. Setting
+ `experimental_distribute.train_distribute` is preferred.
device_fn: A callable invoked for every `Operation` that takes the
`Operation` and returns the device string. If `None`, defaults to
the device function returned by `tf.train.replica_device_setter`
@@ -468,7 +472,13 @@ class RunConfig(object):
eval_distribute: An optional instance of
`tf.contrib.distribute.DistributionStrategy`. If specified,
then Estimator will distribute the user's model during evaluation,
- according to the policy specified by that strategy.
+ according to the policy specified by that strategy. Setting
+ `experimental_distribute.eval_distribute` is preferred.
+ experimental_distribute: An optional
+ `tf.contrib.distribute.DistributeConfig` object specifying
+ DistributionStrategy-related configuration. The `train_distribute` and
+ `eval_distribute` can be passed as parameters to `RunConfig` or set in
+ `experimental_distribute` but not both.
Raises:
ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs`
@@ -508,11 +518,15 @@ class RunConfig(object):
train_distribute=train_distribute,
device_fn=device_fn,
protocol=protocol,
- eval_distribute=eval_distribute)
+ eval_distribute=eval_distribute,
+ experimental_distribute=experimental_distribute)
- self._init_distributed_setting_from_environment_var(tf_config)
-
- self._maybe_overwrite_session_config_for_distributed_training()
+ if train_distribute or eval_distribute or experimental_distribute:
+ logging.info('Initializing RunConfig with distribution strategies.')
+ distribute_coordinator_training.init_run_config(self, tf_config)
+ else:
+ self._init_distributed_setting_from_environment_var(tf_config)
+ self._maybe_overwrite_session_config_for_distributed_training()
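A minimal sketch of the preferred wiring; the `DistributeConfig` fields and its `tf.contrib.distribute` location are assumptions based on the `distribute_config.py` file added in this change:

```python
# Editor's sketch; DistributeConfig's fields and location are assumptions.
import tensorflow as tf
from tensorflow.contrib import distribute

config = tf.estimator.RunConfig(
    experimental_distribute=distribute.DistributeConfig(
        train_distribute=distribute.MirroredStrategy(),
        remote_cluster={'chief': ['host0:2222']}))
```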
def _maybe_overwrite_session_config_for_distributed_training(self):
"""Overwrites the session_config for distributed training.
@@ -810,6 +824,7 @@ class RunConfig(object):
- `device_fn`,
- `protocol`.
- `eval_distribute`,
+ - `experimental_distribute`,
In addition, either `save_checkpoints_steps` or `save_checkpoints_secs`
can be set (should not be both).
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index e6bd263c80..240be5dabe 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -26,6 +26,7 @@ import time
import six
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.distribute import estimator_training as distribute_coordinator_training
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import exporter as exporter_lib
from tensorflow.python.estimator import run_config as run_config_lib
@@ -274,8 +275,10 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
evaluation `input_fn`, steps, etc.
This utility function provides consistent behavior for both local
- (non-distributed) and distributed configurations. Currently, the only
- supported distributed training configuration is between-graph replication.
+ (non-distributed) and distributed configurations. The default distribution
+ configuration is parameter server-based between-graph replication. For other
+ types of distribution configurations such as all-reduce training, please use
+ [DistributionStrategies](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute). # pylint: disable=line-too-long
Overfitting: In order to avoid overfitting, it is recommended to set up the
training `input_fn` to shuffle the training data properly.
@@ -426,6 +429,11 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
}'
```
+ When `train_distribute` or `experimental_distribute.train_distribute` and
+ `experimental_distribute.remote_cluster` are set, this method will start a
+ client running on the current host which connects to the `remote_cluster` for
+ training and evaluation.
+
Args:
estimator: An `Estimator` instance to train and evaluate.
train_spec: A `TrainSpec` instance to specify the training specification.
@@ -444,8 +452,16 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
executor = _TrainingExecutor(
estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
-
config = estimator.config
+
+ # If `distribute_coordinator_mode` is set and running in distributed
+ # environment, we run `train_and_evaluate` via distribute coordinator.
+ if distribute_coordinator_training.should_run_distribute_coordinator(config):
+ logging.info('Running `train_and_evaluate` with Distribute Coordinator.')
+ distribute_coordinator_training.train_and_evaluate(
+ estimator, train_spec, eval_spec, _TrainingExecutor)
+ return
+
if (config.task_type == run_config_lib.TaskType.EVALUATOR and
config.task_id > 0):
raise ValueError(
diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py
index d4a75478d5..31e4778e72 100644
--- a/tensorflow/python/estimator/util.py
+++ b/tensorflow/python/estimator/util.py
@@ -109,13 +109,17 @@ def parse_input_fn_result(result):
else:
input_hooks.append(_DatasetInitializerHook(iterator))
result = iterator.get_next()
+ return parse_iterator_result(result) + (input_hooks,)
+
+def parse_iterator_result(result):
+ """Gets features, labels from result."""
if isinstance(result, (list, tuple)):
if len(result) != 2:
raise ValueError(
'input_fn should return (features, labels) as a len 2 tuple.')
- return result[0], result[1], input_hooks
- return result, None, input_hooks
+ return result[0], result[1]
+ return result, None
class _DatasetInitializerHook(training.SessionRunHook):
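
The refactored helper is self-contained enough to restate and sanity-check outside TensorFlow; the snippet below mirrors `parse_iterator_result` as added above and exercises both result shapes.

```python
def parse_iterator_result(result):
  """Gets (features, labels) from an input_fn result."""
  if isinstance(result, (list, tuple)):
    if len(result) != 2:
      raise ValueError(
          'input_fn should return (features, labels) as a len 2 tuple.')
    return result[0], result[1]
  return result, None

# A (features, labels) tuple passes through unchanged; a bare features value
# gets a None label. parse_input_fn_result then appends the input_hooks.
assert parse_iterator_result(({'x': 1}, 0)) == ({'x': 1}, 0)
assert parse_iterator_result({'x': 1}) == ({'x': 1}, None)
```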
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 9d2babc6e0..9b482237ab 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -2747,6 +2747,62 @@ class FunctionalInputLayerTest(test.TestCase):
variables_lib.Variable)
self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
+ def test_fills_cols_to_vars_shared_embedding(self):
+ # Provide 5 DenseColumns to input_layer: a NumericColumn, a
+ # BucketizedColumn, an EmbeddingColumn, and two SharedEmbeddingColumns.
+ # The EmbeddingColumn creates its own Variable and the two
+ # SharedEmbeddingColumns share one variable.
+ price1 = fc.numeric_column('price1')
+ dense_feature = fc.numeric_column('dense_feature')
+ dense_feature_bucketized = fc.bucketized_column(
+ dense_feature, boundaries=[0.])
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc.embedding_column(
+ some_sparse_column, dimension=10)
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ shared_embedding_a, shared_embedding_b = fc.shared_embedding_columns(
+ [categorical_column_a, categorical_column_b], dimension=2)
+ with ops.Graph().as_default():
+ features = {
+ 'price1': [[3.], [4.]],
+ 'dense_feature': [[-1.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ 'aaa':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2)),
+ 'bbb':
+ sparse_tensor.SparseTensor(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 1),
+ dense_shape=(2, 2)),
+ }
+ cols_to_vars = {}
+ all_cols = [
+ price1, dense_feature_bucketized, some_embedding_column,
+ shared_embedding_a, shared_embedding_b
+ ]
+ fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
+ self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
+ self.assertEqual(0, len(cols_to_vars[price1]))
+ self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
+ self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
+ self.assertEqual(1, len(cols_to_vars[shared_embedding_a]))
+ # This is a bug in the current implementation and should be fixed in the
+ # new one.
+ self.assertEqual(0, len(cols_to_vars[shared_embedding_b]))
+ self.assertIsInstance(cols_to_vars[some_embedding_column][0],
+ variables_lib.Variable)
+ self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
+ self.assertIsInstance(cols_to_vars[shared_embedding_a][0],
+ variables_lib.Variable)
+ self.assertAllEqual(cols_to_vars[shared_embedding_a][0].shape, [3, 2])
+
def test_fills_cols_to_vars_partitioned_variables(self):
price1 = fc.numeric_column('price1')
dense_feature = fc.numeric_column('dense_feature')
@@ -2772,6 +2828,10 @@ class FunctionalInputLayerTest(test.TestCase):
self.assertEqual(0, len(cols_to_vars[price1]))
self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
self.assertEqual(3, len(cols_to_vars[some_embedding_column]))
+ self.assertEqual(
+ 'input_from_feature_columns/input_layer/sparse_feature_embedding/'
+ 'embedding_weights/part_0:0',
+ cols_to_vars[some_embedding_column][0].name)
self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [2, 10])
self.assertAllEqual(cols_to_vars[some_embedding_column][1].shape, [2, 10])
self.assertAllEqual(cols_to_vars[some_embedding_column][2].shape, [1, 10])
@@ -5544,20 +5604,6 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertIsNone(partition_info)
return embedding_values
- # Expected lookup result, using combiner='mean'.
- expected_lookups_a = (
- # example 0:
- (7., 11.), # ids [2], embedding = [7, 11]
- # example 1:
- (2., 3.5), # ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
- )
- expected_lookups_b = (
- # example 0:
- (1., 2.), # ids [0], embedding = [1, 2]
- # example 1:
- (0., 0.), # ids [], embedding = [0, 0]
- )
-
# Build columns.
categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index b6bf516286..aa66ed77e9 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -142,6 +142,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@@ -155,7 +156,6 @@ from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
-from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
@@ -164,67 +164,148 @@ from tensorflow.python.training import checkpoint_utils
from tensorflow.python.util import nest
-def _internal_input_layer(features,
- feature_columns,
- weight_collections=None,
- trainable=True,
- cols_to_vars=None,
- scope=None):
- """See input_layer. `scope` is a name or variable scope to use."""
+class StateManager(object):
+ """Manages the state associated with FeatureColumns.
- feature_columns = fc_old._normalize_feature_columns(feature_columns) # pylint: disable=protected-access
- for column in feature_columns:
- if not isinstance(column, fc_old._DenseColumn): # pylint: disable=protected-access
- raise ValueError(
- 'Items of feature_columns must be a _DenseColumn. '
- 'You can wrap a categorical column with an '
- 'embedding_column or indicator_column. Given: {}'.format(column))
- weight_collections = list(weight_collections or [])
- if ops.GraphKeys.GLOBAL_VARIABLES not in weight_collections:
- weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
- if ops.GraphKeys.MODEL_VARIABLES not in weight_collections:
- weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
-
- # a non-None `scope` can allow for variable reuse, when, e.g., this function
- # is wrapped by a `make_template`.
- with variable_scope.variable_scope(
- scope, default_name='input_layer', values=features.values()):
- builder = fc_old._LazyBuilder(features) # pylint: disable=protected-access
- output_tensors = []
- ordered_columns = []
- for column in sorted(feature_columns, key=lambda x: x.name):
- ordered_columns.append(column)
- with variable_scope.variable_scope(
- None, default_name=column._var_scope_name): # pylint: disable=protected-access
- tensor = column._get_dense_tensor( # pylint: disable=protected-access
- builder,
- weight_collections=weight_collections,
- trainable=trainable)
- num_elements = column._variable_shape.num_elements() # pylint: disable=protected-access
- batch_size = array_ops.shape(tensor)[0]
- output_tensors.append(
- array_ops.reshape(tensor, shape=(batch_size, num_elements)))
- if cols_to_vars is not None:
- # Retrieve any variables created (some _DenseColumn's don't create
- # variables, in which case an empty list is returned).
- cols_to_vars[column] = ops.get_collection(
- ops.GraphKeys.GLOBAL_VARIABLES,
- scope=variable_scope.get_variable_scope().name)
- _verify_static_batch_size_equality(output_tensors, ordered_columns)
- return array_ops.concat(output_tensors, 1)
+ Some `FeatureColumn`s create variables or resources to assist their
+ computation. The `StateManager` is responsible for creating and storing these
+ objects since `FeatureColumn`s are supposed to be stateless configuration
+ only.
+ """
+
+ def create_variable(self,
+ feature_column,
+ name,
+ shape,
+ dtype=None,
+ trainable=True,
+ initializer=None):
+ """Creates a new variable.
+
+ Args:
+ feature_column: A `FeatureColumn` object this variable corresponds to.
+ name: variable name.
+ shape: variable shape.
+ dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
+ trainable: Whether this variable is trainable or not.
+ initializer: initializer instance (callable).
+
+ Returns:
+ The created variable.
+ """
+ del feature_column, name, shape, dtype, trainable, initializer
+ raise NotImplementedError('StateManager.create_variable')
+
+ def add_variable(self, feature_column, var):
+ """Adds an existing variable to the state.
+
+ Args:
+ feature_column: A `FeatureColumn` object to associate this variable with.
+ var: The variable.
+ """
+ del feature_column, var
+ raise NotImplementedError('StateManager.add_variable')
+
+ def get_variable(self, feature_column, name):
+ """Returns an existing variable.
+
+ Args:
+ feature_column: A `FeatureColumn` object this variable corresponds to.
+ name: variable name.
+
+ Returns:
+ The variable.
+ """
+ del feature_column, name
+ raise NotImplementedError('StateManager.get_variable')
+
+ def add_resource(self, feature_column, name, resource):
+ """Adds an existing resource to the state.
+
+ Resources can be things such as tables.
+
+ Args:
+ feature_column: A `FeatureColumn` object this resource corresponds to.
+ name: Name of the resource.
+ resource: The resource.
+ """
+ del feature_column, name, resource
+ raise NotImplementedError('StateManager.add_resource')
+
+ def get_resource(self, feature_column, name):
+ """Returns an already created resource.
-def input_layer(features,
- feature_columns,
- weight_collections=None,
- trainable=True,
- cols_to_vars=None):
- """Returns a dense `Tensor` as input layer based on given `feature_columns`.
+ Resources can be things such as tables.
+
+ Args:
+ feature_column: A `FeatureColumn` object this resource corresponds to.
+ name: Name of the resource.
+
+ Returns:
+ The resource.
+ """
+ del feature_column, name
+ raise NotImplementedError('StateManager.get_resource')
+
+
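
For orientation before the concrete managers below, a minimal `StateManager` subclass can look like the following. It mirrors the `_TestStateManager` defined in feature_column_v2_test.py later in this diff; the class name here is hypothetical.

```python
class DictStateManager(StateManager):  # hypothetical name, for illustration
  """Keys variables by (feature_column, name), like _TestStateManager."""

  def __init__(self):
    self._all_variables = {}

  def create_variable(self, feature_column, name, shape, dtype=None,
                      trainable=True, initializer=None):
    var_dict = self._all_variables.setdefault(feature_column, {})
    if name in var_dict:
      raise ValueError('Variable already exists.')
    var_dict[name] = variable_scope.get_variable(
        name=name, shape=shape, dtype=dtype, trainable=trainable,
        initializer=initializer)
    return var_dict[name]

  def get_variable(self, feature_column, name):
    if feature_column not in self._all_variables:
      raise ValueError('Do not recognize FeatureColumn.')
    return self._all_variables[feature_column][name]
```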
+class _InputLayerStateManager(StateManager):
+ """Manages the state of InputLayer."""
+
+ def __init__(self, layer, feature_columns, trainable):
+ """Creates an _InputLayerStateManager object.
+
+ Args:
+ layer: The input layer this state manager is associated with.
+ feature_columns: List of feature columns for the input layer.
+ trainable: Whether variables created are trainable by default.
+ """
+ self._trainable = trainable
+ self._layer = layer
+ self._cols_to_vars_map = {}
+ self._cols_to_names_map = {}
+ for column in sorted(feature_columns, key=lambda x: x.name):
+ self._cols_to_vars_map[column] = {}
+ base_name = column.name
+ if isinstance(column, SharedEmbeddingColumn):
+ base_name = column.shared_collection_name
+ with variable_scope.variable_scope(base_name) as vs:
+ self._cols_to_names_map[column] = _strip_leading_slashes(vs.name)
+
+ def create_variable(self,
+ feature_column,
+ name,
+ shape,
+ dtype=None,
+ trainable=True,
+ initializer=None):
+ if name in self._cols_to_vars_map[feature_column]:
+ raise ValueError('Variable already exists.')
+ with variable_scope.variable_scope(self._cols_to_names_map[feature_column]):
+ var = self._layer.add_variable(
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ trainable=self._trainable and trainable,
+ # TODO(rohanj): Get rid of this hack once we have a mechanism for
+ # specifying a default partitioner for an entire layer. In that case,
+ # the default getter for Layers should work.
+ getter=variable_scope.get_variable)
+ self._cols_to_vars_map[feature_column][name] = var
+ return var
+
+ def get_variable(self, feature_column, name):
+ if name in self._cols_to_vars_map[feature_column]:
+ return self._cols_to_vars_map[feature_column][name]
+ raise ValueError('Variable does not exist.')
+
+
+class FeatureLayer(Layer):
+ """A layer that produces a dense `Tensor` based on given `feature_columns`.
Generally a single example in training data is described with FeatureColumns.
At the first layer of the model, this column-oriented data should be converted
to a single `Tensor`.
+ This layer can be called multiple times with different features.
+
Example:
```python
@@ -233,105 +314,122 @@ def input_layer(features,
categorical_column_with_hash_bucket("keywords", 10000), dimension=16)
columns = [price, keywords_embedded, ...]
features = tf.parse_example(..., features=make_parse_example_spec(columns))
- dense_tensor = input_layer(features, columns)
+ feature_layer = FeatureLayer(columns)
+ dense_tensor = feature_layer(features)
for units in [128, 64, 32]:
dense_tensor = tf.layers.dense(dense_tensor, units, tf.nn.relu)
- prediction = tf.layers.dense(dense_tensor, 1)
- ```
-
- Args:
- features: A mapping from key to tensors. `_FeatureColumn`s look up via these
- keys. For example `numeric_column('price')` will look at 'price' key in
- this dict. Values can be a `SparseTensor` or a `Tensor` depends on
- corresponding `_FeatureColumn`.
- feature_columns: An iterable containing the FeatureColumns to use as inputs
- to your model. All items should be instances of classes derived from
- `_DenseColumn` such as `numeric_column`, `embedding_column`,
- `bucketized_column`, `indicator_column`. If you have categorical features,
- you can wrap them with an `embedding_column` or `indicator_column`.
- weight_collections: A list of collection names to which the Variable will be
- added. Note that variables will also be added to collections
- `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
- trainable: If `True` also add the variable to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- cols_to_vars: If not `None`, must be a dictionary that will be filled with a
- mapping from `_FeatureColumn` to list of `Variable`s. For example, after
- the call, we might have cols_to_vars =
- {_EmbeddingColumn(
- categorical_column=_HashedCategoricalColumn(
- key='sparse_feature', hash_bucket_size=5, dtype=tf.string),
- dimension=10): [<tf.Variable 'some_variable:0' shape=(5, 10),
- <tf.Variable 'some_variable:1' shape=(5, 10)]}
- If a column creates no variables, its value will be an empty list.
-
- Returns:
- A `Tensor` which represents input layer of a model. Its shape
- is (batch_size, first_layer_dimension) and its dtype is `float32`.
- first_layer_dimension is determined based on given `feature_columns`.
-
- Raises:
- ValueError: if an item in `feature_columns` is not a `_DenseColumn`.
- """
- return _internal_input_layer(features, feature_columns, weight_collections,
- trainable, cols_to_vars)
-
-
-# TODO(akshayka): InputLayer should be a subclass of Layer, and it
-# should implement the logic in input_layer using Layer's build-and-call
-# paradigm; input_layer should create an instance of InputLayer and
-# return the result of invoking its apply method, just as functional layers do.
-class InputLayer(object):
- """An object-oriented version of `input_layer` that reuses variables."""
+ prediction = tf.layers.dense(dense_tensor, 1)
+ ```
+ """
def __init__(self,
feature_columns,
- weight_collections=None,
trainable=True,
- cols_to_vars=None):
- """See `input_layer`."""
+ name=None,
+ shared_state_manager=None,
+ **kwargs):
+ """Constructs a FeatureLayer.
- self._feature_columns = feature_columns
- self._weight_collections = weight_collections
- self._trainable = trainable
- self._cols_to_vars = cols_to_vars
- self._input_layer_template = template.make_template(
- 'feature_column_input_layer',
- _internal_input_layer,
- create_scope_now_=True)
- self._scope = self._input_layer_template.variable_scope
-
- def __call__(self, features):
- return self._input_layer_template(
- features=features,
- feature_columns=self._feature_columns,
- weight_collections=self._weight_collections,
- trainable=self._trainable,
- cols_to_vars=None,
- scope=self._scope)
+ Args:
+ feature_columns: An iterable containing the FeatureColumns to use as
+ inputs to your model. All items should be instances of classes derived
+ from `DenseColumn` such as `numeric_column`, `embedding_column`,
+ `bucketized_column`, `indicator_column`. If you have categorical
+ features, you can wrap them with an `embedding_column` or
+ `indicator_column`.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: Name to give to the FeatureLayer.
+ shared_state_manager: SharedEmbeddingStateManager that manages the state
+ of SharedEmbeddingColumns. Unlike the state of regular embedding columns,
+ this state cannot be owned by the FeatureLayer itself, since
+ SharedEmbeddingColumns can be shared across different FeatureLayers. As a
+ result, users are expected to create a SharedEmbeddingStateManager object
+ which is responsible for managing the shared state and can be passed into
+ different FeatureLayer objects to share state. For example,
+
+ ```python
+ sc_1, sc_2 = shared_embedding_column_v2(...)
+ sc_3, sc_4 = shared_embedding_column_v2(...)
+ ssm = SharedEmbeddingStateManager()
+ feature_layer1 = FeatureLayer([sc_1, sc_3], ...,
+ shared_state_manager=ssm)
+ feature_layer2 = FeatureLayer([sc_2, sc_4], ...,
+ shared_state_manager=ssm)
+ ```
+ Now feature_layer1 and feature_layer2 will share variables. If sharing is
+ not desired, one can create two separate SharedEmbeddingStateManager
+ objects:
+
+ ```python
+ ssm1 = SharedEmbeddingStateManager()
+ ssm2 = SharedEmbeddingStateManager()
+ feature_layer1 = FeatureLayer([sc_1, sc_3], ...,
+ shared_state_manager=ssm1)
+ feature_layer2 = FeatureLayer([sc_2, sc_4], ...,
+ shared_state_manager=ssm2)
+ ```
+ **kwargs: Keyword arguments to construct a layer.
- @property
- def non_trainable_variables(self):
- return self._input_layer_template.non_trainable_variables
+ Raises:
+ ValueError: if an item in `feature_columns` is not a `DenseColumn`.
+ """
+ super(FeatureLayer, self).__init__(name=name, trainable=trainable, **kwargs)
- @property
- def non_trainable_weights(self):
- return self._input_layer_template.non_trainable_weights
+ self._feature_columns = _normalize_feature_columns(feature_columns)
+ self._state_manager = _InputLayerStateManager(self, self._feature_columns,
+ self.trainable)
+ self._shared_state_manager = shared_state_manager
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ if not isinstance(column, DenseColumn):
+ raise ValueError(
+ 'Items of feature_columns must be a DenseColumn. '
+ 'You can wrap a categorical column with an '
+ 'embedding_column or indicator_column. Given: {}'.format(column))
- @property
- def trainable_variables(self):
- return self._input_layer_template.trainable_variables
+ def build(self, _):
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ if isinstance(column, SharedEmbeddingColumn):
+ column.create_state(self._shared_state_manager)
+ else:
+ with variable_scope.variable_scope(None, default_name=self.name):
+ column.create_state(self._state_manager)
+ super(FeatureLayer, self).build(None)
- @property
- def trainable_weights(self):
- return self._input_layer_template.trainable_weights
+ def call(self, features, cols_to_output_tensors=None):
+ """Returns a dense tensor corresponding to the `feature_columns`.
- @property
- def variables(self):
- return self._input_layer_template.variables
+ Args:
+ features: A mapping from key to tensors. `FeatureColumn`s look up via
+ these keys. For example `numeric_column('price')` will look at 'price'
+ key in this dict. Values can be a `SparseTensor` or a `Tensor`, depending
+ on the corresponding `FeatureColumn`.
+ cols_to_output_tensors: If not `None`, must be a dict that will be filled
+ with a mapping from feature columns to the output tensors created.
- @property
- def weights(self):
- return self._input_layer_template.weights
+ Returns:
+ A `Tensor` which represents the input layer of a model. Its shape
+ is (batch_size, first_layer_dimension) and its dtype is `float32`.
+ first_layer_dimension is determined based on the given `feature_columns`.
+ """
+ transformation_cache = FeatureTransformationCache(features)
+ output_tensors = []
+ ordered_columns = []
+ for column in sorted(self._feature_columns, key=lambda x: x.name):
+ ordered_columns.append(column)
+ if isinstance(column, SharedEmbeddingColumn):
+ tensor = column.get_dense_tensor(transformation_cache,
+ self._shared_state_manager)
+ else:
+ tensor = column.get_dense_tensor(transformation_cache,
+ self._state_manager)
+ num_elements = column.variable_shape.num_elements()
+ batch_size = array_ops.shape(tensor)[0]
+ tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
+ output_tensors.append(tensor)
+ if cols_to_output_tensors is not None:
+ cols_to_output_tensors[column] = tensor
+
+ _verify_static_batch_size_equality(output_tensors, ordered_columns)
+ return array_ops.concat(output_tensors, 1)
def linear_model(features,
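
A short usage sketch of `FeatureLayer` (the data is made up for illustration; it matches the `test_cols_to_output_tensors` case added to the tests below):

```python
price1 = fc.numeric_column('price1', shape=2)
price2 = fc.numeric_column('price2')
feature_layer = FeatureLayer([price1, price2])

features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
cols_to_output = {}
net = feature_layer(features, cols_to_output_tensors=cols_to_output)
# `net` is the concatenated (batch_size, 3) tensor [[1., 2., 3.], [5., 6., 4.]]
# and cols_to_output[price1], cols_to_output[price2] hold per-column outputs.
```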
@@ -565,12 +663,15 @@ class _BiasLayer(base.Layer):
return self._bias_variable
-def _get_expanded_variable_list(variable):
- if (isinstance(variable, variables.Variable) or
- resource_variable_ops.is_resource_variable(variable)):
- return [variable] # Single variable case.
- else: # Must be a PartitionedVariable, so convert into a list.
- return list(variable)
+def _get_expanded_variable_list(var_list):
+ returned_list = []
+ for variable in var_list:
+ if (isinstance(variable, variables.Variable) or
+ resource_variable_ops.is_resource_variable(variable)):
+ returned_list.append(variable) # Single variable case.
+ else: # Must be a PartitionedVariable, so convert into a list.
+ returned_list.extend(list(variable))
+ return returned_list
def _strip_leading_slashes(name):
@@ -661,7 +762,7 @@ class _LinearModel(training.Model):
scope=variable_scope.get_variable_scope()), # pylint: disable=not-callable
name='weighted_sum')
bias = self._bias_layer.variables[0]
- self._cols_to_vars['bias'] = _get_expanded_variable_list(bias)
+ self._cols_to_vars['bias'] = _get_expanded_variable_list([bias])
return predictions
def _add_layers(self, layers):
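
The reworked helper takes a list and flattens any `PartitionedVariable` entries into their shards, which is why the `bias` call site above now wraps its argument in a list. A hedged sketch of the behavior (module aliases as in this file and its tests):

```python
v = variables.Variable([1.0])
pv = variable_scope.get_variable(
    'w', shape=[4, 2],
    partitioner=partitioned_variables.fixed_size_partitioner(2))
expanded = _get_expanded_variable_list([v, pv])
# expanded == [v, <shard 0 of pv>, <shard 1 of pv>]: plain and resource
# variables pass through, a PartitionedVariable contributes each shard.
```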
@@ -877,10 +978,15 @@ def embedding_column(
trainable=trainable)
-def shared_embedding_columns(
- categorical_columns, dimension, combiner='mean', initializer=None,
- shared_embedding_collection_name=None, ckpt_to_load_from=None,
- tensor_name_in_ckpt=None, max_norm=None, trainable=True):
+def shared_embedding_columns_v2(categorical_columns,
+ dimension,
+ combiner='mean',
+ initializer=None,
+ shared_embedding_collection_name=None,
+ ckpt_to_load_from=None,
+ tensor_name_in_ckpt=None,
+ max_norm=None,
+ trainable=True):
"""List of dense columns that convert from sparse, categorical input.
This is similar to `embedding_column`, except that it produces a list of
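
Usage is unchanged from the old name, as the updated tests below show; only the returned column type is new:

```python
cat_a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
cat_b = fc.categorical_column_with_identity(key='bbb', num_buckets=3)
emb_a, emb_b = fc.shared_embedding_columns_v2([cat_a, cat_b], dimension=2)
# Both columns share one (3, 2) embedding table, stored by a
# SharedEmbeddingStateManager under the name 'aaa_bbb_shared_embedding'.
```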
@@ -1803,51 +1909,6 @@ def crossed_column(keys, hash_bucket_size, hash_key=None):
keys=tuple(keys), hash_bucket_size=hash_bucket_size, hash_key=hash_key)
-class StateManager(object):
- """Manages the state associated with FeatureColumns.
-
- Some `FeatureColumn`s create variables or resources to assist their
- computation. The `StateManager` is responsible for creating and storing these
- objects since `FeatureColumn`s are supposed to be stateless configuration
- only.
- """
-
- def get_variable(self,
- feature_column,
- name,
- shape,
- dtype=None,
- initializer=None):
- """Creates a new variable or returns an existing one.
-
- Args:
- feature_column: A `FeatureColumn` object this variable corresponds to.
- name: variable name.
- shape: variable shape.
- dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
- initializer: initializer instance (callable).
-
- Returns:
- The variable.
- """
- raise NotImplementedError('StateManager.get_variable')
-
- def get_resource(self, feature_column, name, resource_creator):
- """Creates a new resource or returns an existing one.
-
- Resources can be things such as tables etc.
-
- Args:
- feature_column: A `FeatureColumn` object this variable corresponds to.
- name: Name of the resource.
- resource_creator: A callable that can create the resource.
-
- Returns:
- The resource.
- """
- raise NotImplementedError('StateManager.get_resource')
-
-
class FeatureColumn(object):
"""Represents a feature column abstraction.
@@ -2550,6 +2611,17 @@ class EmbeddingColumn(
"""See `DenseColumn` base class."""
return tensor_shape.vector(self.dimension)
+ def create_state(self, state_manager):
+ """Creates the embedding lookup variable."""
+ embedding_shape = (self.categorical_column.num_buckets, self.dimension)
+ state_manager.create_variable(
+ self,
+ name='embedding_weights',
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ trainable=self.trainable,
+ initializer=self.initializer)
+
def _get_dense_tensor_internal(self, transformation_cache, state_manager):
"""Private method that follows the signature of _get_dense_tensor."""
# Get sparse IDs and weights.
@@ -2558,13 +2630,8 @@ class EmbeddingColumn(
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- embedding_shape = (self.categorical_column.num_buckets, self.dimension)
embedding_weights = state_manager.get_variable(
- self,
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer)
+ self, name='embedding_weights')
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
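
The split gives `EmbeddingColumn` a two-phase protocol: `create_state` builds the variable once (normally from `FeatureLayer.build`), and the dense-tensor path only looks it up. A sketch, reusing the hypothetical `DictStateManager` from earlier and module aliases from the tests:

```python
column = fc.embedding_column(
    fc.categorical_column_with_identity('aaa', num_buckets=3), dimension=2)
manager = DictStateManager()  # stand-in for the layer's state manager
column.create_state(manager)  # creates 'embedding_weights' with shape (3, 2)

sparse_ids = sparse_tensor.SparseTensor(
    indices=((0, 0), (1, 0)), values=(0, 2), dense_shape=(2, 1))
cache = fc.FeatureTransformationCache({'aaa': sparse_ids})
dense = column.get_dense_tensor(cache, manager)  # lookup only, no creation
```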
@@ -2637,6 +2704,68 @@ def _get_graph_for_variable(var):
return var.graph
+class SharedEmbeddingStateManager(Layer):
+ """A state manager that handle the state of shared embedding columns.
+
+ This can handle multiple sets of columns that share variables."""
+
+ def __init__(self, trainable=True, name=None, **kwargs):
+ """Constructs a `SharedEmbeddingStateManager`.
+
+ Args:
+ trainable: If true, variables created are trainable.
+ name: Name of the state manager.
+ **kwargs: Keyword arguments.
+ """
+ super(SharedEmbeddingStateManager, self).__init__(
+ name=name, trainable=trainable, **kwargs)
+ self._var_dict = {}
+
+ def create_variable(self,
+ name,
+ shape,
+ dtype=None,
+ trainable=True,
+ initializer=None):
+ """Creates a variable.
+
+ Makes sure only one variable is created per `name`. Callers such as
+ `SharedEmbeddingColumn.create_state` pass the column's
+ `shared_collection_name` as `name`, so every column in a shared set gets
+ the same variable.
+
+ Args:
+ name: Name of the variable; in practice, the `shared_collection_name`.
+ shape: Variable shape.
+ dtype: Variable type.
+ trainable: Whether the created variable should be trainable.
+ initializer: Variable initializer.
+
+ Returns:
+ A variable or partitioned variable.
+ """
+ if name in self._var_dict:
+ var = self._var_dict[name]
+ return var
+ with variable_scope.variable_scope(
+ self.name, reuse=variable_scope.AUTO_REUSE):
+ var = self.add_variable(
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ trainable=self.trainable and trainable,
+ initializer=initializer,
+ # TODO(rohanj): Get rid of this hack once we have a mechanism for
+ # specifying a default partitioner for an entire layer. In that case,
+ # the default getter for Layers should work.
+ getter=variable_scope.get_variable)
+ self._var_dict[name] = var
+ return var
+
+ def get_variable(self, feature_column, name):
+ if name not in self._var_dict:
+ raise ValueError('Variable name: {} not recognized.'.format(name))
+ return self._var_dict[name]
+
+
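
The dedup contract of `create_variable` can be shown directly: repeated calls with the same name return the cached variable rather than creating a second one.

```python
ssm = SharedEmbeddingStateManager(name='shared_feature_layer')
v1 = ssm.create_variable(name='aaa_bbb_shared_embedding', shape=(3, 2))
v2 = ssm.create_variable(name='aaa_bbb_shared_embedding', shape=(3, 2))
assert v1 is v2  # second call hits the _var_dict cache
```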
class SharedEmbeddingColumn(
DenseColumn, SequenceDenseColumn,
collections.namedtuple(
@@ -2675,6 +2804,16 @@ class SharedEmbeddingColumn(
"""See `DenseColumn` base class."""
return tensor_shape.vector(self.dimension)
+ def create_state(self, state_manager):
+ """Creates the shared embedding lookup variable."""
+ embedding_shape = (self.categorical_column.num_buckets, self.dimension)
+ state_manager.create_variable(
+ name=self.shared_collection_name,
+ shape=embedding_shape,
+ dtype=dtypes.float32,
+ trainable=self.trainable,
+ initializer=self.initializer)
+
def _get_dense_tensor_internal(self, transformation_cache, state_manager):
"""Private method that follows the signature of _get_dense_tensor."""
# This method is called from a variable_scope with name _var_scope_name,
@@ -2687,13 +2826,8 @@ class SharedEmbeddingColumn(
sparse_ids = sparse_tensors.id_tensor
sparse_weights = sparse_tensors.weight_tensor
- embedding_shape = (self.categorical_column.num_buckets, self.dimension)
embedding_weights = state_manager.get_variable(
- self,
- name='embedding_weights',
- shape=embedding_shape,
- dtype=dtypes.float32,
- initializer=self.initializer)
+ self, name=self.shared_collection_name)
if self.ckpt_to_load_from is not None:
to_restore = embedding_weights
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index ad578d287a..6b343ecf3e 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -33,12 +33,12 @@ from tensorflow.python.eager import context
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import feature_column_v2 as fc
+from tensorflow.python.feature_column.feature_column_v2 import _LinearModel
+from tensorflow.python.feature_column.feature_column_v2 import _transform_features
from tensorflow.python.feature_column.feature_column_v2 import FeatureColumn
+from tensorflow.python.feature_column.feature_column_v2 import FeatureLayer
from tensorflow.python.feature_column.feature_column_v2 import FeatureTransformationCache
-from tensorflow.python.feature_column.feature_column_v2 import InputLayer
from tensorflow.python.feature_column.feature_column_v2 import StateManager
-from tensorflow.python.feature_column.feature_column_v2 import _LinearModel
-from tensorflow.python.feature_column.feature_column_v2 import _transform_features
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -824,22 +824,6 @@ class HashedCategoricalColumnTest(test.TestCase):
self.assertEqual(
transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor)
- def DISABLED_test_get_sparse_tensors_weight_collections(self):
- column = fc.categorical_column_with_hash_bucket('aaa', 10)
- inputs = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'],
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- column._get_sparse_tensors(
- FeatureTransformationCache({
- 'aaa': inputs
- }),
- weight_collections=('my_weights',))
-
- self.assertItemsEqual(
- [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
- self.assertItemsEqual([], ops.get_collection('my_weights'))
-
def test_get_sparse_tensors_dense_input(self):
hashed_sparse = fc.categorical_column_with_hash_bucket('wire', 10)
transformation_cache = FeatureTransformationCache({
@@ -2640,13 +2624,13 @@ class _LinearModelTest(test.TestCase):
sess.run(net, feed_dict={features['price']: np.array(1)})
-class InputLayerTest(test.TestCase):
+class FeatureLayerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def test_retrieving_input(self):
features = {'a': [0.]}
- input_layer = InputLayer(fc_old.numeric_column('a'))
- inputs = self.evaluate(input_layer(features))
+ feature_layer = FeatureLayer(fc.numeric_column('a'))
+ inputs = self.evaluate(feature_layer(features))
self.assertAllClose([[0.]], inputs)
def test_reuses_variables(self):
@@ -2657,7 +2641,7 @@ class InputLayerTest(test.TestCase):
dense_shape=(3, 3))
# Create feature columns (categorical and embedding).
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='a', num_buckets=3)
embedding_dimension = 2
def _embedding_column_initializer(shape, dtype, partition_info):
@@ -2670,16 +2654,16 @@ class InputLayerTest(test.TestCase):
(1, 1)) # id 2
return embedding_values
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_embedding_column_initializer)
- input_layer = InputLayer([embedding_column])
+ feature_layer = FeatureLayer([embedding_column])
features = {'a': sparse_input}
- inputs = input_layer(features)
- variables = input_layer.variables
+ inputs = feature_layer(features)
+ variables = feature_layer.variables
# Sanity check: test that the inputs are correct.
self.assertAllEqual([[1, 0], [0, 1], [1, 1]], inputs)
@@ -2687,13 +2671,13 @@ class InputLayerTest(test.TestCase):
# Check that only one variable was created.
self.assertEqual(1, len(variables))
- # Check that invoking input_layer on the same features does not create
+ # Check that invoking feature_layer on the same features does not create
# additional variables
- _ = input_layer(features)
+ _ = feature_layer(features)
self.assertEqual(1, len(variables))
- self.assertEqual(variables[0], input_layer.variables[0])
+ self.assertEqual(variables[0], feature_layer.variables[0])
- def test_feature_column_input_layer_gradient(self):
+ def test_feature_column_feature_layer_gradient(self):
with context.eager_mode():
sparse_input = sparse_tensor.SparseTensor(
indices=((0, 0), (1, 0), (2, 0)),
@@ -2701,7 +2685,7 @@ class InputLayerTest(test.TestCase):
dense_shape=(3, 3))
# Create feature columns (categorical and embedding).
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='a', num_buckets=3)
embedding_dimension = 2
@@ -2715,16 +2699,16 @@ class InputLayerTest(test.TestCase):
(1, 1)) # id 2
return embedding_values
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_embedding_column_initializer)
- input_layer = InputLayer([embedding_column])
+ feature_layer = FeatureLayer([embedding_column])
features = {'a': sparse_input}
def scale_matrix():
- matrix = input_layer(features)
+ matrix = feature_layer(features)
return 2 * matrix
# Sanity check: Verify that scale_matrix returns the correct output.
@@ -2739,185 +2723,139 @@ class InputLayerTest(test.TestCase):
self.assertAllEqual([0, 1, 2], indexed_slice.indices)
self.assertAllEqual([[2, 2], [2, 2], [2, 2]], gradient)
-
-class FunctionalInputLayerTest(test.TestCase):
-
def test_raises_if_empty_feature_columns(self):
with self.assertRaisesRegexp(ValueError,
'feature_columns must not be empty'):
- fc.input_layer(features={}, feature_columns=[])
+ FeatureLayer(feature_columns=[])(features={})
def test_should_be_dense_column(self):
- with self.assertRaisesRegexp(ValueError, 'must be a _DenseColumn'):
- fc.input_layer(
- features={'a': [[0]]},
- feature_columns=[
- fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- ])
+ with self.assertRaisesRegexp(ValueError, 'must be a DenseColumn'):
+ FeatureLayer(feature_columns=[
+ fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ ])(
+ features={
+ 'a': [[0]]
+ })
def test_does_not_support_dict_columns(self):
with self.assertRaisesRegexp(
ValueError, 'Expected feature_columns to be iterable, found dict.'):
- fc.input_layer(
- features={'a': [[0]]},
- feature_columns={'a': fc_old.numeric_column('a')})
+ FeatureLayer(feature_columns={'a': fc.numeric_column('a')})(
+ features={
+ 'a': [[0]]
+ })
def test_bare_column(self):
with ops.Graph().as_default():
features = {'a': [0.]}
- net = fc.input_layer(features, fc_old.numeric_column('a'))
+ net = FeatureLayer(fc.numeric_column('a'))(features)
with _initialized_session():
self.assertAllClose([[0.]], net.eval())
def test_column_generator(self):
with ops.Graph().as_default():
features = {'a': [0.], 'b': [1.]}
- columns = (fc_old.numeric_column(key) for key in features)
- net = fc.input_layer(features, columns)
+ columns = (fc.numeric_column(key) for key in features)
+ net = FeatureLayer(columns)(features)
with _initialized_session():
self.assertAllClose([[0., 1.]], net.eval())
def test_raises_if_duplicate_name(self):
with self.assertRaisesRegexp(
ValueError, 'Duplicate feature column name found for columns'):
- fc.input_layer(
- features={'a': [[0]]},
- feature_columns=[
- fc_old.numeric_column('a'),
- fc_old.numeric_column('a')
- ])
+ FeatureLayer(
+ feature_columns=[fc.numeric_column('a'),
+ fc.numeric_column('a')])(
+ features={
+ 'a': [[0]]
+ })
def test_one_column(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- net = fc.input_layer(features, [price])
+ net = FeatureLayer([price])(features)
with _initialized_session():
self.assertAllClose([[1.], [5.]], net.eval())
def test_multi_dimension(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1., 2.], [5., 6.]]}
- net = fc.input_layer(features, [price])
+ net = FeatureLayer([price])(features)
with _initialized_session():
self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
def test_raises_if_shape_mismatch(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
with self.assertRaisesRegexp(
Exception,
r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
- fc.input_layer(features, [price])
+ FeatureLayer([price])(features)
def test_reshaping(self):
- price = fc_old.numeric_column('price', shape=[1, 2])
+ price = fc.numeric_column('price', shape=[1, 2])
with ops.Graph().as_default():
features = {'price': [[[1., 2.]], [[5., 6.]]]}
- net = fc.input_layer(features, [price])
+ net = FeatureLayer([price])(features)
with _initialized_session():
self.assertAllClose([[1., 2.], [5., 6.]], net.eval())
def test_multi_column(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1., 2.], [5., 6.]],
'price2': [[3.], [4.]]
}
- net = fc.input_layer(features, [price1, price2])
+ net = FeatureLayer([price1, price2])(features)
with _initialized_session():
self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], net.eval())
- def test_fills_cols_to_vars(self):
- # Provide three _DenseColumn's to input_layer: a _NumericColumn, a
- # _BucketizedColumn, and an _EmbeddingColumn. Only the _EmbeddingColumn
- # creates a Variable.
- price1 = fc_old.numeric_column('price1')
- dense_feature = fc_old.numeric_column('dense_feature')
- dense_feature_bucketized = fc_old.bucketized_column(
- dense_feature, boundaries=[0.])
- some_sparse_column = fc_old.categorical_column_with_hash_bucket(
- 'sparse_feature', hash_bucket_size=5)
- some_embedding_column = fc_old.embedding_column(
- some_sparse_column, dimension=10)
- with ops.Graph().as_default():
- features = {
- 'price1': [[3.], [4.]],
- 'dense_feature': [[-1.], [4.]],
- 'sparse_feature': [['a'], ['x']],
- }
- cols_to_vars = {}
- all_cols = [price1, dense_feature_bucketized, some_embedding_column]
- fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
- self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
- self.assertEqual(0, len(cols_to_vars[price1]))
- self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
- self.assertEqual(1, len(cols_to_vars[some_embedding_column]))
- self.assertIsInstance(cols_to_vars[some_embedding_column][0],
- variables_lib.Variable)
- self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [5, 10])
-
- def test_fills_cols_to_vars_partitioned_variables(self):
- price1 = fc_old.numeric_column('price1')
- dense_feature = fc_old.numeric_column('dense_feature')
- dense_feature_bucketized = fc_old.bucketized_column(
- dense_feature, boundaries=[0.])
- some_sparse_column = fc_old.categorical_column_with_hash_bucket(
- 'sparse_feature', hash_bucket_size=5)
- some_embedding_column = fc_old.embedding_column(
- some_sparse_column, dimension=10)
+ def test_cols_to_output_tensors(self):
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
- features = {
- 'price1': [[3.], [4.]],
- 'dense_feature': [[-1.], [4.]],
- 'sparse_feature': [['a'], ['x']],
- }
- cols_to_vars = {}
- all_cols = [price1, dense_feature_bucketized, some_embedding_column]
- with variable_scope.variable_scope(
- 'input_from_feature_columns',
- partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0)):
- fc.input_layer(features, all_cols, cols_to_vars=cols_to_vars)
- self.assertItemsEqual(list(cols_to_vars.keys()), all_cols)
- self.assertEqual(0, len(cols_to_vars[price1]))
- self.assertEqual(0, len(cols_to_vars[dense_feature_bucketized]))
- self.assertEqual(3, len(cols_to_vars[some_embedding_column]))
- self.assertAllEqual(cols_to_vars[some_embedding_column][0].shape, [2, 10])
- self.assertAllEqual(cols_to_vars[some_embedding_column][1].shape, [2, 10])
- self.assertAllEqual(cols_to_vars[some_embedding_column][2].shape, [1, 10])
+ cols_dict = {}
+ features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
+ feature_layer = FeatureLayer([price1, price2])
+ net = feature_layer(features, cols_dict)
+ with _initialized_session():
+ self.assertAllClose([[1., 2.], [5., 6.]], cols_dict[price1].eval())
+ self.assertAllClose([[3.], [4.]], cols_dict[price2].eval())
+ self.assertAllClose([[1., 2., 3.], [5., 6., 4.]], net.eval())
def test_column_order(self):
- price_a = fc_old.numeric_column('price_a')
- price_b = fc_old.numeric_column('price_b')
+ price_a = fc.numeric_column('price_a')
+ price_b = fc.numeric_column('price_b')
with ops.Graph().as_default():
features = {
'price_a': [[1.]],
'price_b': [[3.]],
}
- net1 = fc.input_layer(features, [price_a, price_b])
- net2 = fc.input_layer(features, [price_b, price_a])
+ net1 = FeatureLayer([price_a, price_b])(features)
+ net2 = FeatureLayer([price_b, price_a])(features)
with _initialized_session():
self.assertAllClose([[1., 3.]], net1.eval())
self.assertAllClose([[1., 3.]], net2.eval())
def test_fails_for_categorical_column(self):
- animal = fc_old.categorical_column_with_identity('animal', num_buckets=4)
+ animal = fc.categorical_column_with_identity('animal', num_buckets=4)
with ops.Graph().as_default():
features = {
'animal':
sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
}
- with self.assertRaisesRegexp(Exception, 'must be a _DenseColumn'):
- fc.input_layer(features, [animal])
+ with self.assertRaisesRegexp(Exception, 'must be a DenseColumn'):
+ FeatureLayer([animal])(features)
def test_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1.], [5.], [7.]], # batchsize = 3
@@ -2926,12 +2864,12 @@ class FunctionalInputLayerTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- fc.input_layer(features, [price1, price2])
+ FeatureLayer([price1, price2])(features)
def test_subset_of_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- price3 = fc_old.numeric_column('price3')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ price3 = fc.numeric_column('price3')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
@@ -2941,31 +2879,31 @@ class FunctionalInputLayerTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- fc.input_layer(features, [price1, price2, price3])
+ FeatureLayer([price1, price2, price3])(features)
def test_runtime_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
'price2': [[3.], [4.]] # batchsize = 2
}
- net = fc.input_layer(features, [price1, price2])
+ net = FeatureLayer([price1, price2])(features)
with _initialized_session() as sess:
with self.assertRaisesRegexp(errors.OpError,
'Dimensions of inputs should match'):
sess.run(net, feed_dict={features['price1']: [[1.], [5.], [7.]]})
def test_runtime_batch_size_matches(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
}
- net = fc.input_layer(features, [price1, price2])
+ net = FeatureLayer([price1, price2])(features)
with _initialized_session() as sess:
sess.run(
net,
@@ -2975,9 +2913,9 @@ class FunctionalInputLayerTest(test.TestCase):
})
def test_multiple_layers_with_same_embedding_column(self):
- some_sparse_column = fc_old.categorical_column_with_hash_bucket(
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
'sparse_feature', hash_bucket_size=5)
- some_embedding_column = fc_old.embedding_column(
+ some_embedding_column = fc.embedding_column(
some_sparse_column, dimension=10)
with ops.Graph().as_default():
@@ -2985,28 +2923,30 @@ class FunctionalInputLayerTest(test.TestCase):
'sparse_feature': [['a'], ['x']],
}
all_cols = [some_embedding_column]
- fc.input_layer(features, all_cols)
- fc.input_layer(features, all_cols)
+ FeatureLayer(all_cols)(features)
+ FeatureLayer(all_cols)(features)
# Make sure that 2 variables get created in this case.
self.assertEqual(2, len(
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
expected_var_names = [
- 'input_layer/sparse_feature_embedding/embedding_weights:0',
- 'input_layer_1/sparse_feature_embedding/embedding_weights:0'
+ 'feature_layer/sparse_feature_embedding/embedding_weights:0',
+ 'feature_layer_1/sparse_feature_embedding/embedding_weights:0'
]
self.assertItemsEqual(
expected_var_names,
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
def test_multiple_layers_with_same_shared_embedding_column(self):
- categorical_column_a = fc_old.categorical_column_with_identity(
+ categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=3)
- categorical_column_b = fc_old.categorical_column_with_identity(
+ categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- embedding_column_b, embedding_column_a = fc_old.shared_embedding_columns(
+ embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
[categorical_column_b, categorical_column_a],
dimension=embedding_dimension)
+ shared_state_manager = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
with ops.Graph().as_default():
features = {
@@ -3022,27 +2962,33 @@ class FunctionalInputLayerTest(test.TestCase):
dense_shape=(2, 2)),
}
all_cols = [embedding_column_a, embedding_column_b]
- fc.input_layer(features, all_cols)
- fc.input_layer(features, all_cols)
+ FeatureLayer(
+ all_cols, shared_state_manager=shared_state_manager)(
+ features)
+ FeatureLayer(
+ all_cols, shared_state_manager=shared_state_manager)(
+ features)
# Make sure that only 1 variable gets created in this case.
self.assertEqual(1, len(
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
self.assertItemsEqual(
- ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
+ ['shared_feature_layer/aaa_bbb_shared_embedding:0'],
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
def test_multiple_layers_with_same_shared_embedding_column_diff_graphs(self):
- categorical_column_a = fc_old.categorical_column_with_identity(
+ categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=3)
- categorical_column_b = fc_old.categorical_column_with_identity(
+ categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- embedding_column_b, embedding_column_a = fc_old.shared_embedding_columns(
+ embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
[categorical_column_b, categorical_column_a],
dimension=embedding_dimension)
all_cols = [embedding_column_a, embedding_column_b]
with ops.Graph().as_default():
+ shared_state_manager1 = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
features = {
'aaa':
sparse_tensor.SparseTensor(
@@ -3055,12 +3001,16 @@ class FunctionalInputLayerTest(test.TestCase):
values=(1, 2, 1),
dense_shape=(2, 2)),
}
- fc.input_layer(features, all_cols)
+ FeatureLayer(
+ all_cols, shared_state_manager=shared_state_manager1)(
+ features)
# Make sure that only 1 variable gets created in this case.
self.assertEqual(1, len(
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
with ops.Graph().as_default():
+ shared_state_manager2 = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
features1 = {
'aaa':
sparse_tensor.SparseTensor(
@@ -3074,12 +3024,14 @@ class FunctionalInputLayerTest(test.TestCase):
dense_shape=(2, 2)),
}
- fc.input_layer(features1, all_cols)
+ FeatureLayer(
+ all_cols, shared_state_manager=shared_state_manager2)(
+ features1)
# Make sure that only 1 variable gets created in this case.
self.assertEqual(1, len(
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
self.assertItemsEqual(
- ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
+ ['shared_feature_layer/aaa_bbb_shared_embedding:0'],
[v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
def test_with_numpy_input_fn(self):
@@ -3092,14 +3044,14 @@ class FunctionalInputLayerTest(test.TestCase):
del shape, dtype, partition_info
return embedding_values
- # price has 1 dimension in input_layer
- price = fc_old.numeric_column('price')
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ # price has 1 dimension in feature_layer
+ price = fc.numeric_column('price')
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- # one_hot_body_style has 3 dims in input_layer.
- one_hot_body_style = fc_old.indicator_column(body_style)
- # embedded_body_style has 5 dims in input_layer.
- embedded_body_style = fc_old.embedding_column(
+ # one_hot_body_style has 3 dims in feature_layer.
+ one_hot_body_style = fc.indicator_column(body_style)
+ # embedded_body_style has 5 dims in feature_layer.
+ embedded_body_style = fc.embedding_column(
body_style, dimension=5, initializer=_initializer)
input_fn = numpy_io.numpy_input_fn(
@@ -3110,8 +3062,8 @@ class FunctionalInputLayerTest(test.TestCase):
batch_size=2,
shuffle=False)
features = input_fn()
- net = fc.input_layer(features,
- [price, one_hot_body_style, embedded_body_style])
+ net = FeatureLayer([price, one_hot_body_style, embedded_body_style])(
+ features)
self.assertEqual(1 + 3 + 5, net.shape[1])
with _initialized_session() as sess:
coord = coordinator.Coordinator()
@@ -3137,18 +3089,18 @@ class FunctionalInputLayerTest(test.TestCase):
del shape, dtype, partition_info
return embedding_values
- # price has 1 dimension in input_layer
- price = fc_old.numeric_column('price')
+ # price has 1 dimension in feature_layer
+ price = fc.numeric_column('price')
- # one_hot_body_style has 3 dims in input_layer.
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ # one_hot_body_style has 3 dims in feature_layer.
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- one_hot_body_style = fc_old.indicator_column(body_style)
+ one_hot_body_style = fc.indicator_column(body_style)
- # embedded_body_style has 5 dims in input_layer.
- country = fc_old.categorical_column_with_vocabulary_list(
+ # embedded_body_style has 5 dims in feature_layer.
+ country = fc.categorical_column_with_vocabulary_list(
'country', vocabulary_list=['US', 'JP', 'CA'])
- embedded_country = fc_old.embedding_column(
+ embedded_country = fc.embedding_column(
country, dimension=5, initializer=_initializer)
# Provides 1-dim tensor and dense tensor.
@@ -3165,8 +3117,7 @@ class FunctionalInputLayerTest(test.TestCase):
self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
self.assertEqual(1, features['country'].shape.ndims)
- net = fc.input_layer(features,
- [price, one_hot_body_style, embedded_country])
+ net = FeatureLayer([price, one_hot_body_style, embedded_country])(features)
self.assertEqual(1 + 3 + 5, net.shape[1])
with _initialized_session() as sess:
@@ -3187,18 +3138,18 @@ class FunctionalInputLayerTest(test.TestCase):
del shape, dtype, partition_info
return embedding_values
- # price has 1 dimension in input_layer
- price = fc_old.numeric_column('price')
+ # price has 1 dimension in feature_layer
+ price = fc.numeric_column('price')
- # one_hot_body_style has 3 dims in input_layer.
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ # one_hot_body_style has 3 dims in feature_layer.
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- one_hot_body_style = fc_old.indicator_column(body_style)
+ one_hot_body_style = fc.indicator_column(body_style)
- # embedded_body_style has 5 dims in input_layer.
- country = fc_old.categorical_column_with_vocabulary_list(
+ # embedded_body_style has 5 dims in feature_layer.
+ country = fc.categorical_column_with_vocabulary_list(
'country', vocabulary_list=['US', 'JP', 'CA'])
- embedded_country = fc_old.embedding_column(
+ embedded_country = fc.embedding_column(
country, dimension=2, initializer=_initializer)
# Provides 1-dim tensor and dense tensor.
@@ -3219,8 +3170,7 @@ class FunctionalInputLayerTest(test.TestCase):
dense_shape=(2,))
country_data = np.array([['US'], ['CA']])
- net = fc.input_layer(features,
- [price, one_hot_body_style, embedded_country])
+ net = FeatureLayer([price, one_hot_body_style, embedded_country])(features)
self.assertEqual(1 + 3 + 2, net.shape[1])
with _initialized_session() as sess:
@@ -3237,8 +3187,8 @@ class FunctionalInputLayerTest(test.TestCase):
}))
def test_with_rank_0_feature(self):
- # price has 1 dimension in input_layer
- price = fc_old.numeric_column('price')
+ # price has 1 dimension in feature_layer
+ price = fc.numeric_column('price')
features = {
'price': constant_op.constant(0),
}
@@ -3246,13 +3196,13 @@ class FunctionalInputLayerTest(test.TestCase):
# Static rank 0 should fail
with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
- fc.input_layer(features, [price])
+ FeatureLayer([price])(features)
# Dynamic rank 0 should fail
features = {
'price': array_ops.placeholder(dtypes.float32),
}
- net = fc.input_layer(features, [price])
+ net = FeatureLayer([price])(features)
self.assertEqual(1, net.shape[1])
with _initialized_session() as sess:
with self.assertRaisesOpError('Feature .* cannot have rank 0'):
@@ -3267,7 +3217,7 @@ class MakeParseExampleSpecTest(test.TestCase):
@property
def name(self):
- return "_TestFeatureColumn"
+ return '_TestFeatureColumn'
def transform_feature(self, transformation_cache, state_manager):
pass
@@ -3593,25 +3543,6 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dense_shape=inputs.dense_shape),
id_tensor.eval())
- def DISABLED_test_get_sparse_tensors_weight_collections(self):
- column = fc.categorical_column_with_vocabulary_file(
- key='aaa',
- vocabulary_file=self._wire_vocabulary_file_name,
- vocabulary_size=self._wire_vocabulary_size)
- inputs = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'],
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- column.get_sparse_tensors(
- FeatureTransformationCache({
- 'aaa': inputs
- }),
- weight_collections=('my_weights',))
-
- self.assertItemsEqual(
- [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
- self.assertItemsEqual([], ops.get_collection('my_weights'))
-
def test_get_sparse_tensors_dense_input(self):
column = fc.categorical_column_with_vocabulary_file(
key='aaa',
@@ -4043,24 +3974,6 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
dense_shape=inputs.dense_shape),
id_tensor.eval())
- def DISABLED_test_get_sparse_tensors_weight_collections(self):
- column = fc.categorical_column_with_vocabulary_list(
- key='aaa',
- vocabulary_list=('omar', 'stringer', 'marlo'))
- inputs = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'],
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- column.get_sparse_tensors(
- FeatureTransformationCache({
- 'aaa': inputs
- }),
- weight_collections=('my_weights',))
-
- self.assertItemsEqual(
- [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
- self.assertItemsEqual([], ops.get_collection('my_weights'))
-
def test_get_sparse_tensors_dense_input(self):
column = fc.categorical_column_with_vocabulary_list(
key='aaa',
@@ -4356,22 +4269,6 @@ class IdentityCategoricalColumnTest(test.TestCase):
dense_shape=inputs.dense_shape),
id_tensor.eval())
- def DISABLED_test_get_sparse_tensors_weight_collections(self):
- column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 1, 0),
- dense_shape=(2, 2))
- column.get_sparse_tensors(
- FeatureTransformationCache({
- 'aaa': inputs
- }),
- weight_collections=('my_weights',))
-
- self.assertItemsEqual(
- [], ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
- self.assertItemsEqual([], ops.get_collection('my_weights'))
-
def test_get_sparse_tensors_dense_input(self):
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
id_weight_pair = column.get_sparse_tensors(
@@ -4765,16 +4662,16 @@ class IndicatorColumnTest(test.TestCase):
weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
self.assertAllClose([[2. + 3.]], predictions.eval())
- def test_input_layer(self):
- animal = fc_old.indicator_column(
- fc_old.categorical_column_with_identity('animal', num_buckets=4))
+ def test_feature_layer(self):
+ animal = fc.indicator_column(
+ fc.categorical_column_with_identity('animal', num_buckets=4))
with ops.Graph().as_default():
features = {
'animal':
sparse_tensor.SparseTensor(
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
}
- net = fc.input_layer(features, [animal])
+ net = FeatureLayer([animal])(features)
with _initialized_session():
self.assertAllClose([[0., 1., 1., 0.]], net.eval())
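For reference, the `[[0., 1., 1., 0.]]` expectation above is just the multi-hot encoding of ids {1, 2} over 4 identity buckets; a small NumPy sketch of the same arithmetic:

```python
import numpy as np

num_buckets = 4
ids = [1, 2]
multi_hot = np.zeros(num_buckets, dtype=np.float32)
for i in ids:
  multi_hot[i] += 1.0  # indicator columns count repeated ids
print(multi_hot)  # [0. 1. 1. 0.]
```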
@@ -4786,12 +4683,13 @@ class _TestStateManager(StateManager):
self._all_variables = {}
self._trainable = trainable
- def get_variable(self,
- feature_column,
- name,
- shape,
- dtype=None,
- initializer=None):
+ def create_variable(self,
+ feature_column,
+ name,
+ shape,
+ dtype=None,
+ trainable=True,
+ initializer=None):
if feature_column not in self._all_variables:
self._all_variables[feature_column] = {}
var_dict = self._all_variables[feature_column]
@@ -4801,11 +4699,19 @@ class _TestStateManager(StateManager):
var = variable_scope.get_variable(
name=name,
shape=shape,
- initializer=initializer,
- trainable=self._trainable)
+ dtype=dtype,
+ trainable=self._trainable and trainable,
+ initializer=initializer)
var_dict[name] = var
return var
+ def get_variable(self, feature_column, name):
+ if feature_column not in self._all_variables:
+ raise ValueError('Unrecognized FeatureColumn.')
+ if name in self._all_variables[feature_column]:
+ return self._all_variables[feature_column][name]
+ raise ValueError('Could not find variable.')
+
class EmbeddingColumnTest(test.TestCase):
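The `_TestStateManager` rewrite above splits variable handling into `create_variable` (create and register, honoring both the manager-level and per-variable `trainable` flags) and `get_variable` (a pure lookup that raises for unknown columns or names). A hedged usage sketch against the class as defined above; `some_column` stands in for any FeatureColumn instance:

```python
# Hedged sketch; some_column is a hypothetical FeatureColumn object.
state_manager = _TestStateManager()
var = state_manager.create_variable(
    feature_column=some_column,
    name='embedding_weights',
    shape=[3, 2],
    trainable=True)
# Lookup returns the registered variable; unknown names raise ValueError.
assert state_manager.get_variable(some_column, 'embedding_weights') is var
```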
@@ -4967,6 +4873,7 @@ class EmbeddingColumnTest(test.TestCase):
categorical_column, dimension=embedding_dimension,
initializer=_initializer)
state_manager = _TestStateManager()
+ embedding_column.create_state(state_manager)
# Provide sparse input and get dense result.
embedding_lookup = embedding_column.get_dense_tensor(
@@ -5028,6 +4935,7 @@ class EmbeddingColumnTest(test.TestCase):
categorical_column, dimension=embedding_dimension,
initializer=_initializer)
state_manager = _TestStateManager()
+ embedding_column.create_state(state_manager)
# Provide sparse input and get dense result.
embedding_lookup = embedding_column.get_dense_tensor(
@@ -5043,36 +4951,6 @@ class EmbeddingColumnTest(test.TestCase):
self.assertAllEqual(embedding_values, global_vars[0].eval())
self.assertAllEqual(expected_lookups, embedding_lookup.eval())
- def DISABLED_test_get_dense_tensor_weight_collections(self):
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- # example 2, ids []
- # example 3, ids [1]
- indices=((0, 0), (1, 0), (1, 4), (3, 0)),
- values=(2, 0, 1, 1),
- dense_shape=(4, 5))
-
- # Build columns.
- categorical_column = fc.categorical_column_with_identity(
- key='aaa', num_buckets=3)
- embedding_column = fc.embedding_column(categorical_column, dimension=2)
-
- # Provide sparse input and get dense result.
- embedding_column.get_dense_tensor(
- FeatureTransformationCache({
- 'aaa': sparse_input
- }),
- weight_collections=('my_vars',))
-
- # Assert expected embedding variable and lookups.
- global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(('embedding_weights:0',),
- tuple([v.name for v in global_vars]))
- my_vars = ops.get_collection('my_vars')
- self.assertItemsEqual(
- ('embedding_weights:0',), tuple([v.name for v in my_vars]))
-
def test_get_dense_tensor_placeholder_inputs(self):
# Inputs.
vocabulary_size = 3
@@ -5117,6 +4995,7 @@ class EmbeddingColumnTest(test.TestCase):
categorical_column, dimension=embedding_dimension,
initializer=_initializer)
state_manager = _TestStateManager()
+ embedding_column.create_state(state_manager)
# Provide sparse input and get dense result.
input_indices = array_ops.placeholder(dtype=dtypes.int64)
@@ -5187,6 +5066,7 @@ class EmbeddingColumnTest(test.TestCase):
ckpt_to_load_from=ckpt_path,
tensor_name_in_ckpt=ckpt_tensor)
state_manager = _TestStateManager()
+ embedding_column.create_state(state_manager)
# Provide sparse input and get dense result.
embedding_lookup = embedding_column.get_dense_tensor(
@@ -5354,7 +5234,7 @@ class EmbeddingColumnTest(test.TestCase):
# = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
- def test_input_layer(self):
+ def test_feature_layer(self):
# Inputs.
vocabulary_size = 3
sparse_input = sparse_tensor.SparseTensorValue(
@@ -5392,30 +5272,29 @@ class EmbeddingColumnTest(test.TestCase):
)
# Build columns.
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_initializer)
# Provide sparse input and get dense result.
- input_layer = fc.input_layer({'aaa': sparse_input}, (embedding_column,))
+ l = FeatureLayer((embedding_column,))
+ feature_layer = l({'aaa': sparse_input})
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ('input_layer/aaa_embedding/embedding_weights:0',),
- tuple([v.name for v in global_vars]))
+ self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertItemsEqual(
- ('input_layer/aaa_embedding/embedding_weights:0',),
- tuple([v.name for v in trainable_vars]))
+ self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in trainable_vars]))
with _initialized_session():
self.assertAllEqual(embedding_values, trainable_vars[0].eval())
- self.assertAllEqual(expected_lookups, input_layer.eval())
+ self.assertAllEqual(expected_lookups, feature_layer.eval())
- def test_input_layer_not_trainable(self):
+ def test_feature_layer_not_trainable(self):
# Inputs.
vocabulary_size = 3
sparse_input = sparse_tensor.SparseTensorValue(
@@ -5453,65 +5332,26 @@ class EmbeddingColumnTest(test.TestCase):
)
# Build columns.
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_initializer,
trainable=False)
# Provide sparse input and get dense result.
- input_layer = fc.input_layer({'aaa': sparse_input}, (embedding_column,))
+ feature_layer = FeatureLayer((embedding_column,))({'aaa': sparse_input})
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ('input_layer/aaa_embedding/embedding_weights:0',),
- tuple([v.name for v in global_vars]))
+ self.assertItemsEqual(('feature_layer/aaa_embedding/embedding_weights:0',),
+ tuple([v.name for v in global_vars]))
self.assertItemsEqual(
[], ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
with _initialized_session():
self.assertAllEqual(embedding_values, global_vars[0].eval())
- self.assertAllEqual(expected_lookups, input_layer.eval())
-
-
-class _TestSharedEmbeddingStateManager(StateManager):
- """Manages the state for shared embedding columns.
-
- This can handle multiple groups of shared embedding columns.
- """
-
- def __init__(self, trainable=True):
- # Dict of shared_embedding_collection_name to a dict of variables.
- self._all_variables = {}
- self._trainable = trainable
-
- def get_variable(self,
- feature_column,
- name,
- shape,
- dtype=None,
- initializer=None):
- if not isinstance(feature_column, fc.SharedEmbeddingColumn):
- raise ValueError(
- 'SharedEmbeddingStateManager can only handle SharedEmbeddingColumns. '
- 'Given type: {} '.format(type(feature_column)))
-
- collection_name = feature_column.shared_collection_name
- if collection_name not in self._all_variables:
- self._all_variables[collection_name] = {}
- var_dict = self._all_variables[collection_name]
- if name in var_dict:
- return var_dict[name]
- else:
- var = variable_scope.get_variable(
- name=name,
- shape=shape,
- initializer=initializer,
- trainable=self._trainable)
- var_dict[name] = var
- return var
+ self.assertAllEqual(expected_lookups, feature_layer.eval())
class SharedEmbeddingColumnTest(test.TestCase):
@@ -5522,7 +5362,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- embedding_column_b, embedding_column_a = fc.shared_embedding_columns(
+ embedding_column_b, embedding_column_a = fc.shared_embedding_columns_v2(
[categorical_column_b, categorical_column_a],
dimension=embedding_dimension)
self.assertIs(categorical_column_a, embedding_column_a.categorical_column)
@@ -5560,7 +5400,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
combiner='my_combiner',
@@ -5605,7 +5445,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
embedding_dimension = 2
- original_a, _ = fc.shared_embedding_columns(
+ original_a, _ = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
combiner='my_combiner',
@@ -5613,7 +5453,8 @@ class SharedEmbeddingColumnTest(test.TestCase):
shared_embedding_collection_name='shared_embedding_collection_name',
ckpt_to_load_from='my_ckpt',
tensor_name_in_ckpt='my_ckpt_tensor',
- max_norm=42., trainable=False)
+ max_norm=42.,
+ trainable=False)
for embedding_column_a in (original_a, copy.deepcopy(original_a)):
self.assertEqual('aaa', embedding_column_a.categorical_column.name)
self.assertEqual(3, embedding_column_a.categorical_column.num_buckets)
@@ -5642,8 +5483,9 @@ class SharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=3)
with self.assertRaisesRegexp(ValueError, 'initializer must be callable'):
- fc.shared_embedding_columns(
- [categorical_column_a, categorical_column_b], dimension=2,
+ fc.shared_embedding_columns_v2(
+ [categorical_column_a, categorical_column_b],
+ dimension=2,
initializer='not_fn')
def test_incompatible_column_type(self):
@@ -5656,7 +5498,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError, 'all categorical_columns must have the same type.*'
'IdentityCategoricalColumn.*HashedCategoricalColumn'):
- fc.shared_embedding_columns(
+ fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b, categorical_column_c],
dimension=2)
@@ -5669,11 +5511,11 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='bbb', num_buckets=3)
weighted_categorical_column_b = fc.weighted_categorical_column(
categorical_column_b, weight_feature_key='bbb_weights')
- fc.shared_embedding_columns(
+ fc.shared_embedding_columns_v2(
[weighted_categorical_column_a, categorical_column_b], dimension=2)
- fc.shared_embedding_columns(
+ fc.shared_embedding_columns_v2(
[categorical_column_a, weighted_categorical_column_b], dimension=2)
- fc.shared_embedding_columns(
+ fc.shared_embedding_columns_v2(
[weighted_categorical_column_a, weighted_categorical_column_b],
dimension=2)
@@ -5682,8 +5524,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='aaa', vocabulary_list=('omar', 'stringer', 'marlo'))
b = fc.categorical_column_with_vocabulary_list(
key='bbb', vocabulary_list=('omar', 'stringer', 'marlo'))
- a_embedded, b_embedded = fc.shared_embedding_columns(
- [a, b], dimension=2)
+ a_embedded, b_embedded = fc.shared_embedding_columns_v2([a, b], dimension=2)
data = example_pb2.Example(features=feature_pb2.Features(
feature={
'aaa':
@@ -5717,8 +5558,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
def test_transform_feature(self):
a = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
b = fc.categorical_column_with_identity(key='bbb', num_buckets=3)
- a_embedded, b_embedded = fc.shared_embedding_columns(
- [a, b], dimension=2)
+ a_embedded, b_embedded = fc.shared_embedding_columns_v2([a, b], dimension=2)
features = {
'aaa': sparse_tensor.SparseTensor(
indices=((0, 0), (1, 0), (1, 1)),
@@ -5788,10 +5628,13 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='aaa', num_buckets=vocabulary_size)
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
- dimension=embedding_dimension, initializer=_initializer)
- state_manager = _TestSharedEmbeddingStateManager()
+ dimension=embedding_dimension,
+ initializer=_initializer)
+ state_manager = fc.SharedEmbeddingStateManager(name='shared_feature_layer')
+ embedding_column_a.create_state(state_manager)
+ embedding_column_b.create_state(state_manager)
# Provide sparse input and get dense result.
embedding_lookup_a = embedding_column_a.get_dense_tensor(
@@ -5801,7 +5644,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(('embedding_weights:0',),
+ self.assertItemsEqual(('shared_feature_layer/aaa_bbb_shared_embedding:0',),
tuple([v.name for v in global_vars]))
embedding_var = global_vars[0]
with _initialized_session():
@@ -5809,58 +5652,6 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertAllEqual(expected_lookups_a, embedding_lookup_a.eval())
self.assertAllEqual(expected_lookups_b, embedding_lookup_b.eval())
- def DISABLED_test_get_dense_tensor_weight_collections(self):
- # Inputs.
- vocabulary_size = 3
- # -1 values are ignored.
- input_a = np.array([
- [2, -1, -1], # example 0, ids [2]
- [0, 1, -1]
- ]) # example 1, ids [0, 1]
- input_b = np.array([
- [0, -1, -1], # example 0, ids [0]
- [-1, -1, -1]
- ]) # example 1, ids []
- input_features = {'aaa': input_a, 'bbb': input_b}
-
- # Embedding variable.
- embedding_dimension = 2
- embedding_values = (
- (1., 2.), # id 0
- (3., 5.), # id 1
- (7., 11.) # id 2
- )
-
- def _initializer(shape, dtype, partition_info):
- self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
- self.assertEqual(dtypes.float32, dtype)
- self.assertIsNone(partition_info)
- return embedding_values
-
- # Build columns.
- categorical_column_a = fc.categorical_column_with_identity(
- key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc.categorical_column_with_identity(
- key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
- [categorical_column_a, categorical_column_b],
- dimension=embedding_dimension,
- initializer=_initializer)
-
- fc.input_layer(
- input_features, [embedding_column_a, embedding_column_b],
- weight_collections=('my_vars',))
-
- # Assert expected embedding variable and lookups.
- global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
- tuple(v.name for v in global_vars))
- my_vars = ops.get_collection('my_vars')
- self.assertItemsEqual(
- ('input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
- tuple(v.name for v in my_vars))
-
def test_get_dense_tensor_placeholder_inputs(self):
# Inputs.
vocabulary_size = 3
@@ -5903,10 +5694,13 @@ class SharedEmbeddingColumnTest(test.TestCase):
key='aaa', num_buckets=vocabulary_size)
categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc.shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
- dimension=embedding_dimension, initializer=_initializer)
- state_manager = _TestSharedEmbeddingStateManager()
+ dimension=embedding_dimension,
+ initializer=_initializer)
+ state_manager = fc.SharedEmbeddingStateManager()
+ embedding_column_a.create_state(state_manager)
+ embedding_column_b.create_state(state_manager)
# Provide sparse input and get dense result.
embedding_lookup_a = embedding_column_a.get_dense_tensor(
@@ -6096,7 +5890,7 @@ class SharedEmbeddingColumnTest(test.TestCase):
# = [3*1 + 5*2, 3*0 +5*0] = [13, 0]
self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
- def _test_input_layer(self, trainable=True):
+ def _test_feature_layer(self, trainable=True):
# Inputs.
vocabulary_size = 3
sparse_input_a = sparse_tensor.SparseTensorValue(
@@ -6111,6 +5905,18 @@ class SharedEmbeddingColumnTest(test.TestCase):
indices=((0, 0),),
values=(0,),
dense_shape=(2, 5))
+ sparse_input_c = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 1), (1, 1), (1, 3)),
+ values=(2, 0, 1),
+ dense_shape=(2, 5))
+ sparse_input_d = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids []
+ indices=((0, 1),),
+ values=(2,),
+ dense_shape=(2, 5))
# Embedding variable.
embedding_dimension = 2
@@ -6130,51 +5936,127 @@ class SharedEmbeddingColumnTest(test.TestCase):
# example 0:
# A ids [2], embedding = [7, 11]
# B ids [0], embedding = [1, 2]
- (7., 11., 1., 2.),
+ # C ids [2], embedding = [7, 11]
+ # D ids [2], embedding = [7, 11]
+ (7., 11., 1., 2., 7., 11., 7., 11.),
# example 1:
# A ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
# B ids [], embedding = [0, 0]
- (2., 3.5, 0., 0.),
+ # C ids [0, 1], embedding = mean([1, 2] + [3, 5]) = [2, 3.5]
+ # D ids [], embedding = [0, 0]
+ (2., 3.5, 0., 0., 2., 3.5, 0., 0.),
)
# Build columns.
- categorical_column_a = fc_old.categorical_column_with_identity(
+ categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc_old.categorical_column_with_identity(
+ categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
+ categorical_column_c = fc.categorical_column_with_identity(
+ key='ccc', num_buckets=vocabulary_size)
+ categorical_column_d = fc.categorical_column_with_identity(
+ key='ddd', num_buckets=vocabulary_size)
+
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
initializer=_initializer,
trainable=trainable)
+ embedding_column_c, embedding_column_d = fc.shared_embedding_columns_v2(
+ [categorical_column_c, categorical_column_d],
+ dimension=embedding_dimension,
+ initializer=_initializer,
+ trainable=trainable)
+ shared_state_manager = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
+
+ features = {
+ 'aaa': sparse_input_a,
+ 'bbb': sparse_input_b,
+ 'ccc': sparse_input_c,
+ 'ddd': sparse_input_d
+ }
# Provide sparse input and get dense result.
- input_layer = fc.input_layer(
- features={'aaa': sparse_input_a, 'bbb': sparse_input_b},
- feature_columns=(embedding_column_b, embedding_column_a))
+ feature_layer = FeatureLayer(
+ feature_columns=(embedding_column_b, embedding_column_a,
+ embedding_column_c, embedding_column_d),
+ shared_state_manager=shared_state_manager)(
+ features)
# Assert expected embedding variable and lookups.
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertItemsEqual(
- ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
- tuple([v.name for v in global_vars]))
+ self.assertItemsEqual([
+ 'shared_feature_layer/aaa_bbb_shared_embedding:0',
+ 'shared_feature_layer/ccc_ddd_shared_embedding:0'
+ ], tuple([v.name for v in global_vars]))
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
if trainable:
- self.assertItemsEqual(
- ['input_layer/aaa_bbb_shared_embedding/embedding_weights:0'],
- tuple([v.name for v in trainable_vars]))
+ self.assertItemsEqual([
+ 'shared_feature_layer/aaa_bbb_shared_embedding:0',
+ 'shared_feature_layer/ccc_ddd_shared_embedding:0'
+ ], tuple([v.name for v in trainable_vars]))
else:
self.assertItemsEqual([], tuple([v.name for v in trainable_vars]))
shared_embedding_vars = global_vars
with _initialized_session():
self.assertAllEqual(embedding_values, shared_embedding_vars[0].eval())
- self.assertAllEqual(expected_lookups, input_layer.eval())
+ self.assertAllEqual(expected_lookups, feature_layer.eval())
+
+ def test_feature_layer(self):
+ self._test_feature_layer()
+
+ def test_feature_layer_no_trainable(self):
+ self._test_feature_layer(trainable=False)
+
- def test_input_layer(self):
- self._test_input_layer()
+class SharedEmbeddingStateManagerTest(test.TestCase):
- def test_input_layer_no_trainable(self):
- self._test_input_layer(trainable=False)
+ def test_basic(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ fc.shared_embedding_columns_v2(
+ [categorical_column_a, categorical_column_b], dimension=2)
+ shared_state_manager = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
+ var_a = shared_state_manager.create_variable('aaa_bbb_shared_embedding',
+ [5, 10])
+ var_b = shared_state_manager.create_variable('aaa_bbb_shared_embedding',
+ [5, 10])
+ self.assertEqual(var_a, var_b)
+ self.assertEqual('shared_feature_layer/aaa_bbb_shared_embedding:0',
+ var_a.name)
+ self.assertIsInstance(var_a, variables_lib.Variable)
+
+ def test_multiple_sets(self):
+ categorical_column_a = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ categorical_column_b = fc.categorical_column_with_identity(
+ key='bbb', num_buckets=3)
+ categorical_column_c = fc.categorical_column_with_identity(
+ key='ccc', num_buckets=3)
+ categorical_column_d = fc.categorical_column_with_identity(
+ key='ddd', num_buckets=3)
+
+ fc.shared_embedding_columns_v2(
+ [categorical_column_a, categorical_column_b], dimension=2)
+ fc.shared_embedding_columns_v2(
+ [categorical_column_c, categorical_column_d], dimension=2)
+ shared_state_manager = fc.SharedEmbeddingStateManager(
+ name='shared_feature_layer')
+ var_a = shared_state_manager.create_variable('aaa_bbb_shared_embedding',
+ [5, 10])
+ var_c = shared_state_manager.create_variable('ccc_ddd_shared_embedding',
+ [5, 10])
+ self.assertIsInstance(var_a, variables_lib.Variable)
+ self.assertIsInstance(var_c, variables_lib.Variable)
+ self.assertNotEqual(var_a, var_c)
+ self.assertEqual('shared_feature_layer/aaa_bbb_shared_embedding:0',
+ var_a.name)
+ self.assertEqual('shared_feature_layer/ccc_ddd_shared_embedding:0',
+ var_c.name)
class WeightedCategoricalColumnTest(test.TestCase):
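The new `SharedEmbeddingStateManagerTest` above pins down the manager's contract: one variable per shared-embedding name, reused across `create_variable` calls, with distinct names for distinct column groups. A hedged condensation of `test_basic`:

```python
# Hedged sketch mirroring test_basic above.
manager = fc.SharedEmbeddingStateManager(name='shared_feature_layer')
v1 = manager.create_variable('aaa_bbb_shared_embedding', [5, 10])
v2 = manager.create_variable('aaa_bbb_shared_embedding', [5, 10])
assert v1 == v2  # second call returns the existing variable
assert v1.name == 'shared_feature_layer/aaa_bbb_shared_embedding:0'
```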
diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py
index ab06a2babf..06c653097a 100644
--- a/tensorflow/python/framework/device.py
+++ b/tensorflow/python/framework/device.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import copy
+import threading
from tensorflow.python.util.tf_export import tf_export
@@ -229,6 +230,12 @@ class DeviceSpec(object):
"""
return DeviceSpec().parse_from_string(spec)
+ def __eq__(self, other):
+ return self.to_string() == other.to_string()
+
+ def __hash__(self):
+ return hash(self.to_string())
+
def check_valid(spec):
"""Check that a device spec is valid.
@@ -254,6 +261,14 @@ def canonical_name(device):
return device.to_string()
+# Cache from DeviceSpec objects to their corresponding device functions.
+# This cache is maintained for correctness, not performance: it makes it
+# possible to compare the device function stacks belonging to different
+# graphs in a meaningful way.
+_cached_device_functions = {}
+_cache_lock = threading.Lock()
+
+
def merge_device(spec):
"""Returns a device function that merges devices specifications.
@@ -280,11 +295,18 @@ def merge_device(spec):
Raises:
ValueError: if the spec was not valid.
"""
- if not isinstance(spec, DeviceSpec):
- spec = DeviceSpec.from_string(spec or "")
- def _device_function(node_def):
- current_device = DeviceSpec.from_string(node_def.device or "")
- copy_spec = copy.copy(spec)
- copy_spec.merge_from(current_device) # current_device takes precedence.
- return copy_spec
- return _device_function
+ with _cache_lock:
+ if not isinstance(spec, DeviceSpec):
+ spec = DeviceSpec.from_string(spec or "")
+ cached_function = _cached_device_functions.get(spec, None)
+ if cached_function is not None:
+ return cached_function
+
+ def _device_function(node_def):
+ current_device = DeviceSpec.from_string(node_def.device or "")
+ copy_spec = copy.copy(spec)
+ copy_spec.merge_from(current_device) # current_device takes precedence.
+ return copy_spec
+
+ _cached_device_functions[spec] = _device_function
+ return _device_function
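Because `merge_device` now memoizes on the (newly hashable) `DeviceSpec`, equal specs yield the identical function object, which is what makes device-function stacks from different graphs comparable, as the cache comment above notes. A hedged sketch:

```python
# Hedged sketch; merge_device is the function patched above.
from tensorflow.python.framework import device

f1 = device.merge_device('/job:worker/device:GPU:0')
f2 = device.merge_device('/job:worker/device:GPU:0')
assert f1 is f2  # same cached callable for equal DeviceSpecs
```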
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index f47c0d8a5e..a8aef3a009 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -23,7 +23,6 @@ from __future__ import print_function
import collections
import hashlib
-import sys
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
@@ -34,7 +33,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util import compat
@@ -42,9 +40,6 @@ from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect
-# This is to avoid a circular dependency with cond_v2_impl.
-cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-access
-
class Defun(object):
"""Decorator used to define TensorFlow functions.
@@ -1029,20 +1024,10 @@ def _from_definition(fdef, grad_func=None):
result = _DefinedFunction(func, argnames, input_types, func_name, grad_func,
python_grad_func, out_names)
# pylint: disable=protected-access
- if ops._USE_C_API:
- serialized = fdef.SerializeToString()
- c_func = c_api.TF_FunctionImportFunctionDef(serialized)
- result._c_func = c_api_util.ScopedTFFunction(c_func)
- result._extra_inputs = []
- else:
- result._definition = fdef
- # Captured inputs are added as regular inputs to a function when it's
- # serialized, i.e. any extra inputs from the original function are now
- # included in `result`._args
- result._extra_inputs = []
- result._hash_str = result._create_hash_str(
- result._definition.signature.input_arg,
- result._definition.signature.output_arg, result._definition.node_def)
+ serialized = fdef.SerializeToString()
+ c_func = c_api.TF_FunctionImportFunctionDef(serialized)
+ result._c_func = c_api_util.ScopedTFFunction(c_func)
+ result._extra_inputs = []
# pylint: enable=protected-access
return result
diff --git a/tensorflow/python/framework/function_def_to_graph.py b/tensorflow/python/framework/function_def_to_graph.py
index 1b09506662..a04fa369ae 100644
--- a/tensorflow/python/framework/function_def_to_graph.py
+++ b/tensorflow/python/framework/function_def_to_graph.py
@@ -23,7 +23,7 @@ import sys
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
-from tensorflow.python.framework import function
+from tensorflow.python.eager import function
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import versions
@@ -34,13 +34,13 @@ cond_v2_impl._function_def_to_graph = sys.modules[__name__] # pylint: disable=p
def function_def_to_graph(fdef, input_shapes=None):
- """Converts a FunctionDef to a function._FuncGraph (sub-class Graph).
+ """Converts a FunctionDef to a function.FuncGraph (sub-class Graph).
- The returned _FuncGraph's `name`, `inputs` and `outputs` fields will be set.
+ The returned FuncGraph's `name`, `inputs` and `outputs` fields will be set.
The input tensors are represented as placeholders.
- Note: `_FuncGraph.inputs` and `_FuncGraph._captured` are not set and may be
- set by the caller.
+ Note: `FuncGraph.captures` is not set and may be set by the caller.
Args:
fdef: FunctionDef.
@@ -50,9 +50,9 @@ def function_def_to_graph(fdef, input_shapes=None):
placeholder will have unknown shape.
Returns:
- A _FuncGraph.
+ A FuncGraph.
"""
- func_graph = function._FuncGraph(fdef.signature.name, capture_by_value=False) # pylint: disable=protected-access
+ func_graph = function.FuncGraph(fdef.signature.name)
graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
fdef, input_shapes)
@@ -60,7 +60,7 @@ def function_def_to_graph(fdef, input_shapes=None):
# Add all function nodes to the graph.
importer.import_graph_def(graph_def, name="")
- # Initialize fields specific to _FuncGraph.
+ # Initialize fields specific to FuncGraph.
# inputs
input_tensor_names = [
@@ -144,6 +144,8 @@ def function_def_to_graph_def(fdef, input_shapes=None):
for arg_def in fdef.signature.input_arg:
nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
+ control_name = "^" + arg_def.name
+ nested_to_flat_tensor_name[control_name] = control_name
for node_def in fdef.node_def:
op_def = ops.get_default_graph()._get_op_def(node_def.op) # pylint: disable=protected-access
@@ -172,6 +174,8 @@ def function_def_to_graph_def(fdef, input_shapes=None):
flat_name = "{}:{}".format(node_def.name, flattened_index)
nested_to_flat_tensor_name[nested_name] = flat_name
flattened_index += 1
+ control_name = "^" + node_def.name
+ nested_to_flat_tensor_name[control_name] = control_name
# Update inputs of all nodes in graph.
for node_def in graph_def.node:
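The two insertions above extend the nested-to-flat name map with control-edge entries, so `^op`-style inputs survive the conversion alongside data edges. Illustrative entries, taken from the updated test below:

```python
# Data edges map "op:out:idx" -> "op:flat_idx"; control edges keep "^op".
nested_to_flat_tensor_name = {
    'x': 'x:0',
    '^x': '^x',          # control dependency on an input placeholder
    'foo_1:d:0': 'foo_1:0',
    '^foo_1': '^foo_1',  # control dependency on a function-call node
}
```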
diff --git a/tensorflow/python/framework/function_def_to_graph_test.py b/tensorflow/python/framework/function_def_to_graph_test.py
index 21d2c7d990..e013fb6e4d 100644
--- a/tensorflow/python/framework/function_def_to_graph_test.py
+++ b/tensorflow/python/framework/function_def_to_graph_test.py
@@ -18,9 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import graph_to_function_def
from tensorflow.python.framework import ops
@@ -154,14 +154,20 @@ class FunctionDefToGraphDefTest(test.TestCase):
self.assertDictEqual(
tensor_name_map, {
"x": "x:0",
+ "^x": "^x",
"y": "y:0",
+ "^y": "^y",
"z": "z:0",
+ "^z": "^z",
"foo_1:d:0": "foo_1:0",
"foo_1:e:0": "foo_1:1",
+ "^foo_1": "^foo_1",
"list_output:a:0": "list_output:0",
"list_output:a:1": "list_output:1",
+ "^list_output": "^list_output",
"foo_2:d:0": "foo_2:0",
"foo_2:e:0": "foo_2:1",
+ "^foo_2": "^foo_2",
})
def testShapes(self):
@@ -184,23 +190,26 @@ class FunctionDefToGraphDefTest(test.TestCase):
x = constant_op.constant(5.0)
y = constant_op.constant(10.0)
- @function.Defun()
+ @function.defun
def fn():
- @function.Defun()
+ @function.defun
def inner_fn():
return x + y
return inner_fn()
- # Instantiate the function in this graph so that
- # `function_def_to_graph` can find it.
- fn()
-
+ @function.defun
def fn2():
return 2 * fn()
- fdef = function._DefinedFunction(fn2, [], []).definition
+ fn2_defun = fn2.get_concrete_function()
+
+ # Call `fn2` to make sure `fn` is correctly instantiated so
+ # `function_def_to_graph` can find it.
+ fn2_defun()
+
+ fdef = fn2_defun._inference_function.definition
func_graph = function_def_to_graph.function_def_to_graph(fdef)
with func_graph.as_default():
x_ph, y_ph = func_graph.inputs
@@ -211,6 +220,26 @@ class FunctionDefToGraphDefTest(test.TestCase):
y_ph: 10.0
}), 30.0)
+ def testControlDependencies(self):
+
+ @function.defun
+ def fn(inp):
+ x = constant_op.constant(2.0, name="x")
+ # TODO(b/79881896): Test external control dependency once that's
+ # supported.
+ with ops.control_dependencies([x, inp]):
+ constant_op.constant(3.0, name="y")
+ return 4.0
+
+ inp = constant_op.constant(1.0)
+ fdef = fn.get_concrete_function(inp).function_def
+ func_graph = function_def_to_graph.function_def_to_graph(fdef)
+
+ op = func_graph.get_operation_by_name("y")
+ self.assertEqual(len(op.control_inputs), 2)
+ self.assertEqual(op.control_inputs[0].name, "x")
+ self.assertEqual(op.control_inputs[1].name, "placeholder")
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 21eb306865..8d72eb39c0 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import collections
import copy
-import os
import re
import sys
import threading
@@ -67,7 +66,7 @@ from tensorflow.python.util.tf_export import tf_export
# Temporary global switches determining if we should enable the work-in-progress
# calls to the C API. These will be removed once all functionality is supported.
_USE_C_API = True
-_USE_C_SHAPES = os.getenv("TF_C_API_GRAPH_CONSTRUCTION_SHAPES", "1") != "0"
+_USE_C_SHAPES = True
def tensor_id(tensor):
@@ -516,6 +515,11 @@ class Tensor(_TensorLike):
==> TensorShape([Dimension(28), Dimension(28), Dimension(3)])
```
+ NOTE: This shape is not enforced at runtime. Setting incorrect shapes can
+ result in inconsistencies between the statically-known graph and the runtime
+ value of tensors. For runtime validation of the shape, use `tf.ensure_shape`
+ instead.
+
Args:
shape: A `TensorShape` representing the shape of this tensor, a
`TensorShapeProto`, a list, a tuple, or None.
@@ -753,6 +757,9 @@ class _EagerTensorBase(Tensor):
def __format__(self, format_spec):
return self.numpy().__format__(format_spec)
+ def __reduce__(self):
+ return (convert_to_tensor, (self.numpy(),))
+
def _numpy(self):
raise NotImplementedError()
@@ -2856,19 +2863,11 @@ class Graph(object):
# TODO(skyewm): fold as much of the above as possible into the C
# implementation
- if self._use_c_api_hack():
- self._scoped_c_graph = c_api_util.ScopedTFGraph()
- # The C API requires all ops to have shape functions. Disable this
- # requirement (many custom ops do not have shape functions, and we don't
- # want to break these existing cases).
- c_api.SetRequireShapeInferenceFns(self._c_graph, False)
- else:
- self._scoped_c_graph = None
-
- # TODO(apassos) remove once the C API is used by default.
- def _use_c_api_hack(self):
- """Temporary hack; can be overridden to force C API usage."""
- return _USE_C_API
+ self._scoped_c_graph = c_api_util.ScopedTFGraph()
+ # The C API requires all ops to have shape functions. Disable this
+ # requirement (many custom ops do not have shape functions, and we don't
+ # want to break these existing cases).
+ c_api.SetRequireShapeInferenceFns(self._c_graph, False)
# Note: this method is private because the API of tf.Graph() is public and
# frozen, and this functionality is still not ready for public visibility.
@@ -3118,7 +3117,7 @@ class Graph(object):
Returns:
bool indicating whether or not 'name' is registered in function library.
"""
- return name in self._functions
+ return compat.as_str(name) in self._functions
def _get_function(self, name):
"""Returns the function definition for 'name'.
@@ -3128,7 +3127,7 @@ class Graph(object):
Returns:
The function def proto.
"""
- return self._functions.get(name, None)
+ return self._functions.get(compat.as_str(name), None)
def _add_function(self, function):
"""Adds a function to the graph.
@@ -3164,7 +3163,7 @@ class Graph(object):
c_api.TF_GraphCopyFunction(self._c_graph, function._c_func.func, gradient)
# pylint: enable=protected-access
- self._functions[name] = function
+ self._functions[compat.as_str(name)] = function
# Need a new-enough consumer to support the functions we add to the graph.
if self._graph_def_versions.min_consumer < 12:
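Normalizing through `compat.as_str` above keeps the graph's function registry keyed consistently whether callers pass `bytes` or `str` names (a Python 2/3 hazard). A one-line sketch:

```python
from tensorflow.python.util import compat

assert compat.as_str(b'my_fn') == compat.as_str(u'my_fn') == 'my_fn'
```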
@@ -5223,6 +5222,7 @@ _default_graph_stack = _DefaultGraphStack()
# pylint: disable=g-doc-return-or-yield,line-too-long
+@tf_export("init_scope")
@tf_contextlib.contextmanager
def init_scope():
"""A context manager that lifts ops out of control-flow scopes and function-building graphs.
@@ -5252,6 +5252,23 @@ def init_scope():
(3) The gradient tape is paused while the scope is active.
+ When eager execution is enabled, code inside an init_scope block runs with
+ eager execution enabled even when defining graph functions via
+ tf.contrib.eager.defun. For example:
+
+ ```python
+ tf.enable_eager_execution()
+
+ @tf.contrib.eager.defun
+ def func():
+ # A defun-decorated function constructs TensorFlow graphs,
+ # it does not execute eagerly.
+ assert not tf.executing_eagerly()
+ with tf.init_scope():
+ # Initialization runs with eager execution enabled
+ assert tf.executing_eagerly()
+ ```
+
Raises:
RuntimeError: if graph state is incompatible with this initialization.
"""
@@ -5382,11 +5399,12 @@ def enable_eager_execution(config=None,
TensorFlow graph, or if options provided conflict with a previous call
to this function.
"""
- return enable_eager_execution_internal(
- config=config,
- device_policy=device_policy,
- execution_mode=execution_mode,
- server_def=None)
+ if context._default_mode != context.EAGER_MODE: # pylint: disable=protected-access
+ return enable_eager_execution_internal(
+ config=config,
+ device_policy=device_policy,
+ execution_mode=execution_mode,
+ server_def=None)
def enable_eager_execution_internal(config=None,
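Two behavioral notes from this file's hunks: `enable_eager_execution` is now a no-op when eager mode is already the default, and the `__reduce__` added to `_EagerTensorBase` makes eager tensors picklable by value. A hedged sketch of the latter, assuming this branch's TF 1.x-style eager API:

```python
import pickle
import tensorflow as tf

tf.enable_eager_execution()  # idempotent after the change above
t = tf.constant([1, 2, 3])
restored = pickle.loads(pickle.dumps(t))  # rebuilt via convert_to_tensor(numpy)
assert (restored.numpy() == t.numpy()).all()
```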
diff --git a/tensorflow/python/framework/ops_enable_eager_test.py b/tensorflow/python/framework/ops_enable_eager_test.py
new file mode 100644
index 0000000000..99d06f1c2d
--- /dev/null
+++ b/tensorflow/python/framework/ops_enable_eager_test.py
@@ -0,0 +1,38 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests enabling eager execution at process level."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import googletest
+
+
+class OpsEnableEagerTest(googletest.TestCase):
+
+ def test_enable_eager_execution_multiple_times(self):
+ ops.enable_eager_execution()
+ self.assertTrue(context.executing_eagerly())
+
+ # Calling enable_eager_execution a second time should not cause an error.
+ ops.enable_eager_execution()
+ self.assertTrue(context.executing_eagerly())
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc
index 8eb943b960..e20ad5fd33 100644
--- a/tensorflow/python/framework/python_op_gen_main.cc
+++ b/tensorflow/python/framework/python_op_gen_main.cc
@@ -52,7 +52,7 @@ Status ReadOpListFromFile(const string& filename,
if (scanner.One(strings::Scanner::LETTER_DIGIT_DOT)
.Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
.GetResult(nullptr, &op_name)) {
- op_list->emplace_back(op_name.ToString());
+ op_list->emplace_back(op_name);
}
s = input_buffer->ReadLine(&line_contents);
}
diff --git a/tensorflow/python/framework/smart_cond.py b/tensorflow/python/framework/smart_cond.py
index 48a834392b..7ee2b5b347 100644
--- a/tensorflow/python/framework/smart_cond.py
+++ b/tensorflow/python/framework/smart_cond.py
@@ -77,11 +77,9 @@ def smart_constant_value(pred):
pred_value = pred
elif isinstance(pred, ops.Tensor):
pred_value = tensor_util.constant_value(pred)
- # TODO(skyewm): consider folding this into tensor_util.constant_value when
- # _USE_C_API is removed (there may be performance and correctness bugs, so I
- # wanted to limit the change hidden behind _USE_C_API).
+ # TODO(skyewm): consider folding this into tensor_util.constant_value.
# pylint: disable=protected-access
- if pred_value is None and ops._USE_C_API:
+ if pred_value is None:
pred_value = c_api.TF_TryEvaluateConstant_wrapper(pred.graph._c_graph,
pred._as_tf_output())
# pylint: enable=protected-access
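With the `_USE_C_API` guard gone, `smart_constant_value` always falls back to the C API constant evaluator when `tensor_util.constant_value` cannot fold the predicate. A hedged graph-mode sketch:

```python
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import smart_cond
from tensorflow.python.ops import array_ops

assert smart_cond.smart_constant_value(constant_op.constant(True))  # folds
assert smart_cond.smart_constant_value(
    array_ops.placeholder(dtypes.bool)) is None  # not statically known
```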
diff --git a/tensorflow/python/framework/subscribe.py b/tensorflow/python/framework/subscribe.py
index cee7398974..00759eb611 100644
--- a/tensorflow/python/framework/subscribe.py
+++ b/tensorflow/python/framework/subscribe.py
@@ -137,12 +137,7 @@ def _subscribe_new(tensor, side_effects, control_cache):
# are subscribed at the same time, we remove the control dependency from
# the original op only once and we add the dependencies to all the
# new identities.
- if ops._USE_C_API: # pylint: disable=protected-access
- new_control_inputs = consumer_op.control_inputs
- else:
- # Make a copy so we don't modify the actual control inputs (this is fixed
- # in the C API).
- new_control_inputs = list(consumer_op.control_inputs)
+ new_control_inputs = consumer_op.control_inputs
if tensor.op in new_control_inputs:
new_control_inputs.remove(tensor.op)
new_control_inputs.append(out.op)
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 155134fac4..7cddd861c8 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -547,7 +547,7 @@ def assert_no_new_tensors(f):
f(self, **kwargs)
# Make an effort to clear caches, which would otherwise look like leaked
# Tensors.
- context.get_default_context()._clear_caches() # pylint: disable=protected-access
+ context.context()._clear_caches() # pylint: disable=protected-access
gc.collect()
tensors_after = [
obj for obj in gc.get_objects()
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index f68c0ddecb..a0939f98b2 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -121,6 +121,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
else:
print("MKL is disabled")
+ @test_util.run_in_graph_and_eager_modes
def testAssertProtoEqualsStr(self):
graph_str = "node { name: 'w1' op: 'params' }"
@@ -133,6 +134,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
# test original comparison
self.assertProtoEquals(graph_def, graph_def)
+ @test_util.run_in_graph_and_eager_modes
def testAssertProtoEqualsAny(self):
# Test assertProtoEquals with a protobuf.Any field.
meta_graph_def_str = """
@@ -161,6 +163,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
r'meta_graph_version: "inner"'):
self.assertProtoEquals("", meta_graph_def_outer)
+ @test_util.run_in_graph_and_eager_modes
def testNDArrayNear(self):
a1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
a2 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@@ -168,6 +171,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertTrue(self._NDArrayNear(a1, a2, 1e-5))
self.assertFalse(self._NDArrayNear(a1, a3, 1e-5))
+ @test_util.run_in_graph_and_eager_modes
def testCheckedThreadSucceeds(self):
def noop(ev):
@@ -181,6 +185,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
t.join()
self.assertTrue(event_arg.is_set())
+ @test_util.run_in_graph_and_eager_modes
def testCheckedThreadFails(self):
def err_func():
@@ -192,6 +197,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
t.join()
self.assertTrue("integer division or modulo by zero" in str(fe.exception))
+ @test_util.run_in_graph_and_eager_modes
def testCheckedThreadWithWrongAssertionFails(self):
x = 37
@@ -204,6 +210,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
t.join()
self.assertTrue("False is not true" in str(fe.exception))
+ @test_util.run_in_graph_and_eager_modes
def testMultipleThreadsWithOneFailure(self):
def err_func(i):
@@ -232,6 +239,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
original_op=op_orig)
raise errors.UnauthenticatedError(node_def, op, "true_err")
+ @test_util.run_in_graph_and_eager_modes
def testAssertRaisesOpErrorDoesNotPassMessageDueToLeakedStack(self):
with self.assertRaises(AssertionError):
self._WeMustGoDeeper("this_is_not_the_error_you_are_looking_for")
@@ -240,6 +248,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self._WeMustGoDeeper("name")
self._WeMustGoDeeper("orig")
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseTensors(self):
a_raw_data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
a = constant_op.constant(a_raw_data)
@@ -255,17 +264,20 @@ class TestUtilTest(test_util.TensorFlowTestCase):
y_list = [a_raw_data, b]
self.assertAllClose(x_list, y_list)
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseScalars(self):
self.assertAllClose(7, 7 + 1e-8)
with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"):
self.assertAllClose(7, 7 + 1e-5)
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseDictToNonDict(self):
with self.assertRaisesRegexp(ValueError, r"Can't compare dict to non-dict"):
self.assertAllClose(1, {"a": 1})
with self.assertRaisesRegexp(ValueError, r"Can't compare dict to non-dict"):
self.assertAllClose({"a": 1}, 1)
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseNamedtuples(self):
a = 7
b = (2., 3.)
@@ -278,6 +290,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertAllClose(
my_named_tuple(a=a, b=b, c=c), my_named_tuple(a=a, b=b, c=c))
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseDicts(self):
a = 7
b = (2., 3.)
@@ -305,6 +318,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"):
self.assertAllClose(expected, {"a": a, "b": b, "c": c_copy})
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseListOfNamedtuples(self):
my_named_tuple = collections.namedtuple("MyNamedTuple", ["x", "y"])
l1 = [
@@ -317,6 +331,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
]
self.assertAllClose(l1, l2)
+ @test_util.run_in_graph_and_eager_modes
def testAllCloseNestedStructure(self):
a = {"x": np.ones((3, 2, 4)) * 7, "y": (2, [{"nested": {"m": 3, "n": 4}}])}
self.assertAllClose(a, a)
@@ -330,6 +345,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
r"\[y\]\[1\]\[0\]\[nested\]\[n\]"):
self.assertAllClose(a, b)
+ @test_util.run_in_graph_and_eager_modes
def testArrayNear(self):
a = [1, 2]
b = [1, 2, 5]
@@ -352,6 +368,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
y = [15]
control_flow_ops.Assert(x, y).run()
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllCloseAccordingToType(self):
# test plain int
self.assertAllCloseAccordingToType(1, 1, rtol=1e-8, atol=1e-8)
@@ -428,6 +445,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
half_rtol=1e-4, half_atol=1e-4
)
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllEqual(self):
i = variables.Variable([100] * 3, dtype=dtypes.int32, name="i")
j = constant_op.constant([20] * 3, dtype=dtypes.int32, name="j")
@@ -437,6 +455,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertAllEqual([120] * 3, k)
self.assertAllEqual([20] * 3, j)
+ @test_util.run_in_graph_and_eager_modes
def testAssertNotAllClose(self):
# Test with arrays
self.assertNotAllClose([0.1], [0.2])
@@ -453,6 +472,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with self.assertRaises(AssertionError):
self.assertNotAllClose([1.0, 1.0], x)
+ @test_util.run_in_graph_and_eager_modes
def testAssertNotAllCloseRTol(self):
# Test with arrays
with self.assertRaises(AssertionError):
@@ -467,6 +487,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with self.assertRaises(AssertionError):
self.assertNotAllClose([0.9, 1.0], x, rtol=0.2)
+ @test_util.run_in_graph_and_eager_modes
def testAssertNotAllCloseATol(self):
# Test with arrays
with self.assertRaises(AssertionError):
@@ -481,6 +502,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with self.assertRaises(AssertionError):
self.assertNotAllClose([0.9, 1.0], x, atol=0.2)
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllGreaterLess(self):
x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32)
y = constant_op.constant([10.0] * 3, dtype=dtypes.float32)
@@ -501,6 +523,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with self.assertRaises(AssertionError):
self.assertAllLess(x, 95.0)
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllGreaterLessEqual(self):
x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32)
y = constant_op.constant([10.0] * 3, dtype=dtypes.float32)
@@ -533,6 +556,7 @@ class TestUtilTest(test_util.TensorFlowTestCase):
with self.assertRaises(AssertionError):
self.assertAllInRange(b, 0, 1)
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllInRange(self):
x = constant_op.constant([10.0, 15.0], name="x")
self.assertAllInRange(x, 10, 15)
@@ -545,24 +569,28 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertAllInRange(
x, 10, 15, open_lower_bound=True, open_upper_bound=True)
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllInRangeErrorMessageEllipses(self):
x_init = np.array([[10.0, 15.0]] * 12)
x = constant_op.constant(x_init, name="x")
with self.assertRaises(AssertionError):
self.assertAllInRange(x, 5, 10)
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllInRangeDetectsNaNs(self):
x = constant_op.constant(
[[np.nan, 0.0], [np.nan, np.inf], [np.inf, np.nan]], name="x")
with self.assertRaises(AssertionError):
self.assertAllInRange(x, 0.0, 2.0)
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllInRangeWithInfinities(self):
x = constant_op.constant([10.0, np.inf], name="x")
self.assertAllInRange(x, 10, np.inf)
with self.assertRaises(AssertionError):
self.assertAllInRange(x, 10, np.inf, open_upper_bound=True)
+ @test_util.run_in_graph_and_eager_modes
def testAssertAllInSet(self):
b = constant_op.constant([True, False], name="b")
x = constant_op.constant([13, 37], name="x")
diff --git a/tensorflow/python/grappler/graph_analyzer.i b/tensorflow/python/grappler/graph_analyzer.i
new file mode 100644
index 0000000000..cc7b5358eb
--- /dev/null
+++ b/tensorflow/python/grappler/graph_analyzer.i
@@ -0,0 +1,26 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+%{
+#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h"
+%}
+
+%{
+void GraphAnalyzer(const string& file_path, int n) {
+ tensorflow::grappler::graph_analyzer::GraphAnalyzerTool(file_path, n);
+}
+%}
+
+void GraphAnalyzer(const string& file_path, int n);
diff --git a/tensorflow/python/grappler/graph_analyzer.py b/tensorflow/python/grappler/graph_analyzer.py
new file mode 100644
index 0000000000..ec5544e38e
--- /dev/null
+++ b/tensorflow/python/grappler/graph_analyzer.py
@@ -0,0 +1,46 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""A tool that finds all subgraphs of a given size in a TF graph.
+
+The subgraph patterns are sorted by occurrence, and only the transitive fanin
+part of the graph with regard to the fetch nodes is considered.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+from tensorflow.python import pywrap_tensorflow as tf_wrap
+from tensorflow.python.platform import app
+
+
+def main(_):
+ tf_wrap.GraphAnalyzer(FLAGS.input, FLAGS.n)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--input",
+ type=str,
+ default=None,
+ help="Input file path for a TensorFlow MetaGraphDef.")
+ parser.add_argument(
+ "--n", type=int, default=None, help="The size of the subgraphs.")
+ FLAGS, unparsed = parser.parse_known_args()
+ app.run(main=main, argv=[sys.argv[0]] + unparsed)
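The new module is a thin flag-parsing front end over the SWIG-wrapped `GraphAnalyzer` entry point declared in the `.i` file above; a hedged programmatic equivalent of `main()`, with a hypothetical input path:

```python
from tensorflow.python import pywrap_tensorflow as tf_wrap

# Hypothetical MetaGraphDef path; n is the subgraph size, as in --n above.
tf_wrap.GraphAnalyzer('/tmp/model.metagraph', 4)
```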
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index e145b894f5..5523d70a8d 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -102,7 +102,6 @@ py_library(
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
- "//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
)
@@ -140,6 +139,7 @@ py_library(
":backend",
"//tensorflow/python/data",
"//tensorflow/python/training/checkpointable:data_structures",
+ "//tensorflow/tools/docs:doc_controls",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/keras/activations_test.py b/tensorflow/python/keras/activations_test.py
index 5cff1f8f9c..dd0bbcff39 100644
--- a/tensorflow/python/keras/activations_test.py
+++ b/tensorflow/python/keras/activations_test.py
@@ -45,7 +45,7 @@ class KerasActivationsTest(test.TestCase):
assert fn == ref_fn
def test_softmax(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.softmax(x)])
test_values = np.random.random((2, 5))
@@ -59,7 +59,7 @@ class KerasActivationsTest(test.TestCase):
keras.activations.softmax(x)
def test_temporal_softmax(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.placeholder(shape=(2, 2, 3))
f = keras.backend.function([x], [keras.activations.softmax(x)])
test_values = np.random.random((2, 2, 3)) * 10
@@ -73,7 +73,7 @@ class KerasActivationsTest(test.TestCase):
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
- with self.test_session():
+ with self.cached_session():
positive_values = np.array([[1, 2]], dtype=keras.backend.floatx())
result = f([positive_values])[0]
self.assertAllClose(result, positive_values * scale, rtol=1e-05)
@@ -87,7 +87,7 @@ class KerasActivationsTest(test.TestCase):
def softplus(x):
return np.log(np.ones_like(x) + np.exp(x))
- with self.test_session():
+ with self.cached_session():
x = keras.backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.softplus(x)])
test_values = np.random.random((2, 5))
@@ -99,7 +99,7 @@ class KerasActivationsTest(test.TestCase):
def softsign(x):
return np.divide(x, np.ones_like(x) + np.absolute(x))
- with self.test_session():
+ with self.cached_session():
x = keras.backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.softsign(x)])
test_values = np.random.random((2, 5))
@@ -116,7 +116,7 @@ class KerasActivationsTest(test.TestCase):
return z / (1 + z)
sigmoid = np.vectorize(ref_sigmoid)
- with self.test_session():
+ with self.cached_session():
x = keras.backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.sigmoid(x)])
test_values = np.random.random((2, 5))
@@ -130,7 +130,7 @@ class KerasActivationsTest(test.TestCase):
z = 0.0 if x <= 0 else (1.0 if x >= 1 else x)
return z
hard_sigmoid = np.vectorize(ref_hard_sigmoid)
- with self.test_session():
+ with self.cached_session():
x = keras.backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.hard_sigmoid(x)])
test_values = np.random.random((2, 5))
@@ -139,7 +139,7 @@ class KerasActivationsTest(test.TestCase):
self.assertAllClose(result, expected, rtol=1e-05)
def test_relu(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.relu(x)])
test_values = np.random.random((2, 5))
@@ -148,7 +148,7 @@ class KerasActivationsTest(test.TestCase):
self.assertAllClose(result, test_values, rtol=1e-05)
def test_elu(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.elu(x, 0.5)])
test_values = np.random.random((2, 5))
@@ -160,7 +160,7 @@ class KerasActivationsTest(test.TestCase):
self.assertAllClose(result, true_result)
def test_tanh(self):
- with self.test_session():
+ with self.cached_session():
test_values = np.random.random((2, 5))
x = keras.backend.placeholder(ndim=2)
exp = keras.activations.tanh(x)
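The test_session-to-cached_session migration above, repeated across the keras test files in this change, replaces per-call session creation with a session cached on the test case. A minimal sketch of the new pattern, assuming the `cached_session` helper on `tf.test.TestCase` at this commit:

    import numpy as np

    from tensorflow.python import keras
    from tensorflow.python.platform import test


    class ExampleTest(test.TestCase):

      def test_tanh(self):
        # cached_session() reuses one session across calls within the test,
        # where test_session() could create a new session on every call.
        with self.cached_session():
          x = keras.backend.constant(np.zeros((2, 5)))
          self.assertAllClose(
              keras.backend.eval(keras.backend.tanh(x)), np.zeros((2, 5)))


    if __name__ == '__main__':
      test.main()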
diff --git a/tensorflow/python/keras/applications/__init__.py b/tensorflow/python/keras/applications/__init__.py
index cd9462d6b5..a8b6d55e41 100644
--- a/tensorflow/python/keras/applications/__init__.py
+++ b/tensorflow/python/keras/applications/__init__.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""Keras Applications are canned architectures with pre-trained weights."""
# pylint: disable=g-import-not-at-top
+# pylint: disable=g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -25,13 +26,49 @@ from tensorflow.python.keras import engine
from tensorflow.python.keras import layers
from tensorflow.python.keras import models
from tensorflow.python.keras import utils
+from tensorflow.python.util import tf_inspect
+
+# `get_submodules_from_kwargs` was introduced in keras_applications 1.0.5,
+# but we would like to handle prior versions as well. Note that prior to
+# 1.0.5, `keras_applications` did not expose a `__version__` attribute.
+if not hasattr(keras_applications, 'get_submodules_from_kwargs'):
+
+ if 'engine' in tf_inspect.getfullargspec(
+ keras_applications.set_keras_submodules)[0]:
+ keras_applications.set_keras_submodules(
+ backend=backend,
+ layers=layers,
+ models=models,
+ utils=utils,
+ engine=engine)
+ else:
+ keras_applications.set_keras_submodules(
+ backend=backend,
+ layers=layers,
+ models=models,
+ utils=utils)
+
+
+def keras_modules_injection(base_fun):
+ """Decorator injecting tf.keras replacements for Keras modules.
+
+ Arguments:
+ base_fun: Application function to decorate (e.g. `MobileNet`).
+
+ Returns:
+ Decorated function that injects keyword arguments for the tf.keras
+ modules required by the Applications.
+ """
+
+ def wrapper(*args, **kwargs):
+ if hasattr(keras_applications, 'get_submodules_from_kwargs'):
+ kwargs['backend'] = backend
+ kwargs['layers'] = layers
+ kwargs['models'] = models
+ kwargs['utils'] = utils
+ return base_fun(*args, **kwargs)
+ return wrapper
-keras_applications.set_keras_submodules(
- backend=backend,
- engine=engine,
- layers=layers,
- models=models,
- utils=utils)
from tensorflow.python.keras.applications.densenet import DenseNet121
from tensorflow.python.keras.applications.densenet import DenseNet169
@@ -39,7 +76,7 @@ from tensorflow.python.keras.applications.densenet import DenseNet201
from tensorflow.python.keras.applications.inception_resnet_v2 import InceptionResNetV2
from tensorflow.python.keras.applications.inception_v3 import InceptionV3
from tensorflow.python.keras.applications.mobilenet import MobileNet
-# TODO(fchollet): enable MobileNetV2 in next version.
+from tensorflow.python.keras.applications.mobilenet_v2 import MobileNetV2
from tensorflow.python.keras.applications.nasnet import NASNetLarge
from tensorflow.python.keras.applications.nasnet import NASNetMobile
from tensorflow.python.keras.applications.resnet50 import ResNet50
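A sketch of what `keras_modules_injection` amounts to at call time once keras_applications 1.0.5 or newer is installed; `DenseNet121_expanded` is a hypothetical name, and the imports mirror the ones at the top of this file:

    from keras_applications import densenet

    from tensorflow.python.keras import backend
    from tensorflow.python.keras import layers
    from tensorflow.python.keras import models
    from tensorflow.python.keras import utils


    def DenseNet121_expanded(*args, **kwargs):
      # The decorator injects the tf.keras submodules, so keras_applications
      # assembles the architecture against tf.keras rather than standalone
      # Keras.
      kwargs['backend'] = backend
      kwargs['layers'] = layers
      kwargs['models'] = models
      kwargs['utils'] = utils
      return densenet.DenseNet121(*args, **kwargs)

On older keras_applications releases the keyword arguments are omitted and the submodules registered globally via `set_keras_submodules` are used instead.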
diff --git a/tensorflow/python/keras/applications/applications_test.py b/tensorflow/python/keras/applications/applications_test.py
index ef3198a937..b15ca5990a 100644
--- a/tensorflow/python/keras/applications/applications_test.py
+++ b/tensorflow/python/keras/applications/applications_test.py
@@ -32,7 +32,8 @@ MODEL_LIST = [
(applications.InceptionV3, 2048),
(applications.InceptionResNetV2, 1536),
(applications.MobileNet, 1024),
- # TODO(fchollet): enable MobileNetV2 in next version.
+ # TODO(fchollet): enable MobileNetV2 tests when a new TensorFlow test image
+ # is released with keras_applications upgraded to 1.0.5 or above.
(applications.DenseNet121, 1024),
(applications.DenseNet169, 1664),
(applications.DenseNet201, 1920),
@@ -44,11 +45,6 @@ MODEL_LIST = [
class ApplicationsTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(*MODEL_LIST)
- def test_classification_model(self, model_fn, _):
- model = model_fn(classes=1000, weights=None)
- self.assertEqual(model.output_shape[-1], 1000)
-
- @parameterized.parameters(*MODEL_LIST)
def test_feature_extration_model(self, model_fn, output_dim):
model = model_fn(include_top=False, weights=None)
self.assertEqual(model.output_shape, (None, None, None, output_dim))
diff --git a/tensorflow/python/keras/applications/densenet.py b/tensorflow/python/keras/applications/densenet.py
index fbdcc66d2d..172848bbdb 100644
--- a/tensorflow/python/keras/applications/densenet.py
+++ b/tensorflow/python/keras/applications/densenet.py
@@ -20,18 +20,39 @@ from __future__ import division
from __future__ import print_function
from keras_applications import densenet
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-DenseNet121 = densenet.DenseNet121
-DenseNet169 = densenet.DenseNet169
-DenseNet201 = densenet.DenseNet201
-decode_predictions = densenet.decode_predictions
-preprocess_input = densenet.preprocess_input
-
-tf_export('keras.applications.densenet.DenseNet121',
- 'keras.applications.DenseNet121')(DenseNet121)
-tf_export('keras.applications.densenet.DenseNet169',
- 'keras.applications.DenseNet169')(DenseNet169)
-tf_export('keras.applications.densenet.DenseNet201',
- 'keras.applications.DenseNet201')(DenseNet201)
-tf_export('keras.applications.densenet.preprocess_input')(preprocess_input)
+
+@tf_export('keras.applications.densenet.DenseNet121',
+ 'keras.applications.DenseNet121')
+@keras_modules_injection
+def DenseNet121(*args, **kwargs):
+ return densenet.DenseNet121(*args, **kwargs)
+
+
+@tf_export('keras.applications.densenet.DenseNet169',
+ 'keras.applications.DenseNet169')
+@keras_modules_injection
+def DenseNet169(*args, **kwargs):
+ return densenet.DenseNet169(*args, **kwargs)
+
+
+@tf_export('keras.applications.densenet.DenseNet201',
+ 'keras.applications.DenseNet201')
+@keras_modules_injection
+def DenseNet201(*args, **kwargs):
+ return densenet.DenseNet201(*args, **kwargs)
+
+
+@tf_export('keras.applications.densenet.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return densenet.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.densenet.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return densenet.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/imagenet_utils.py b/tensorflow/python/keras/applications/imagenet_utils.py
index 70f8f6fb32..c25b5c2bdd 100644
--- a/tensorflow/python/keras/applications/imagenet_utils.py
+++ b/tensorflow/python/keras/applications/imagenet_utils.py
@@ -19,27 +19,18 @@ from __future__ import division
from __future__ import print_function
from keras_applications import imagenet_utils
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-decode_predictions = imagenet_utils.decode_predictions
-preprocess_input = imagenet_utils.preprocess_input
-tf_export(
- 'keras.applications.imagenet_utils.decode_predictions',
- 'keras.applications.densenet.decode_predictions',
- 'keras.applications.inception_resnet_v2.decode_predictions',
- 'keras.applications.inception_v3.decode_predictions',
- 'keras.applications.mobilenet.decode_predictions',
- 'keras.applications.mobilenet_v2.decode_predictions',
- 'keras.applications.nasnet.decode_predictions',
- 'keras.applications.resnet50.decode_predictions',
- 'keras.applications.vgg16.decode_predictions',
- 'keras.applications.vgg19.decode_predictions',
- 'keras.applications.xception.decode_predictions',
-)(decode_predictions)
-tf_export(
- 'keras.applications.imagenet_utils.preprocess_input',
- 'keras.applications.resnet50.preprocess_input',
- 'keras.applications.vgg16.preprocess_input',
- 'keras.applications.vgg19.preprocess_input',
-)(preprocess_input)
+@tf_export('keras.applications.imagenet_utils.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return imagenet_utils.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.imagenet_utils.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return imagenet_utils.preprocess_input(*args, **kwargs)
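A usage sketch for the re-exported helper; the scores are random stand-ins for a 1000-way ImageNet prediction, and the class-index file is downloaded on first use:

    import numpy as np
    import tensorflow as tf

    preds = np.random.random((1, 1000)).astype('float32')
    preds /= preds.sum()
    # Maps the 1000-way scores to (class_id, class_name, score) triples.
    print(tf.keras.applications.imagenet_utils.decode_predictions(preds, top=3))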
diff --git a/tensorflow/python/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/applications/inception_resnet_v2.py
index 63debb4e0d..0b9ef371fa 100644
--- a/tensorflow/python/keras/applications/inception_resnet_v2.py
+++ b/tensorflow/python/keras/applications/inception_resnet_v2.py
@@ -20,13 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import inception_resnet_v2
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-InceptionResNetV2 = inception_resnet_v2.InceptionResNetV2
-decode_predictions = inception_resnet_v2.decode_predictions
-preprocess_input = inception_resnet_v2.preprocess_input
-tf_export('keras.applications.inception_resnet_v2.InceptionResNetV2',
- 'keras.applications.InceptionResNetV2')(InceptionResNetV2)
-tf_export(
- 'keras.applications.inception_resnet_v2.preprocess_input')(preprocess_input)
+@tf_export('keras.applications.inception_resnet_v2.InceptionResNetV2',
+ 'keras.applications.InceptionResNetV2')
+@keras_modules_injection
+def InceptionResNetV2(*args, **kwargs):
+ return inception_resnet_v2.InceptionResNetV2(*args, **kwargs)
+
+
+@tf_export('keras.applications.inception_resnet_v2.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return inception_resnet_v2.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.inception_resnet_v2.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return inception_resnet_v2.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/inception_v3.py b/tensorflow/python/keras/applications/inception_v3.py
index 87534086c8..ab76826e17 100644
--- a/tensorflow/python/keras/applications/inception_v3.py
+++ b/tensorflow/python/keras/applications/inception_v3.py
@@ -20,12 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import inception_v3
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-InceptionV3 = inception_v3.InceptionV3
-decode_predictions = inception_v3.decode_predictions
-preprocess_input = inception_v3.preprocess_input
-tf_export('keras.applications.inception_v3.InceptionV3',
- 'keras.applications.InceptionV3')(InceptionV3)
-tf_export('keras.applications.inception_v3.preprocess_input')(preprocess_input)
+@tf_export('keras.applications.inception_v3.InceptionV3',
+ 'keras.applications.InceptionV3')
+@keras_modules_injection
+def InceptionV3(*args, **kwargs):
+ return inception_v3.InceptionV3(*args, **kwargs)
+
+
+@tf_export('keras.applications.inception_v3.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return inception_v3.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.inception_v3.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return inception_v3.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/mobilenet.py b/tensorflow/python/keras/applications/mobilenet.py
index 3528f027b3..1f71a5ae99 100644
--- a/tensorflow/python/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/applications/mobilenet.py
@@ -20,12 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import mobilenet
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-MobileNet = mobilenet.MobileNet
-decode_predictions = mobilenet.decode_predictions
-preprocess_input = mobilenet.preprocess_input
-tf_export('keras.applications.mobilenet.MobileNet',
- 'keras.applications.MobileNet')(MobileNet)
-tf_export('keras.applications.mobilenet.preprocess_input')(preprocess_input)
+@tf_export('keras.applications.mobilenet.MobileNet',
+ 'keras.applications.MobileNet')
+@keras_modules_injection
+def MobileNet(*args, **kwargs):
+ return mobilenet.MobileNet(*args, **kwargs)
+
+
+@tf_export('keras.applications.mobilenet.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return mobilenet.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.mobilenet.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return mobilenet.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/mobilenet_v2.py b/tensorflow/python/keras/applications/mobilenet_v2.py
index 9194c3ee14..52ac5959ad 100644
--- a/tensorflow/python/keras/applications/mobilenet_v2.py
+++ b/tensorflow/python/keras/applications/mobilenet_v2.py
@@ -19,4 +19,26 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# TODO(fchollet): export MobileNetV2 as part of the public API in next version.
+from keras_applications import mobilenet_v2
+
+from tensorflow.python.keras.applications import keras_modules_injection
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export('keras.applications.mobilenet_v2.MobileNetV2',
+ 'keras.applications.MobileNetV2')
+@keras_modules_injection
+def MobileNetV2(*args, **kwargs):
+ return mobilenet_v2.MobileNetV2(*args, **kwargs)
+
+
+@tf_export('keras.applications.mobilenet_v2.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return mobilenet_v2.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.mobilenet_v2.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return mobilenet_v2.preprocess_input(*args, **kwargs)
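With the export in place the model is reachable from the public API again. A minimal usage sketch; weights=None skips the pretrained-weight download, and the input shape and class count are illustrative:

    import tensorflow as tf

    # Build an untrained MobileNetV2 for a 10-class problem.
    model = tf.keras.applications.MobileNetV2(
        input_shape=(224, 224, 3), weights=None, classes=10)
    print(model.output_shape)  # (None, 10)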
diff --git a/tensorflow/python/keras/applications/nasnet.py b/tensorflow/python/keras/applications/nasnet.py
index 26ff5db53f..44fc329d57 100644
--- a/tensorflow/python/keras/applications/nasnet.py
+++ b/tensorflow/python/keras/applications/nasnet.py
@@ -20,15 +20,32 @@ from __future__ import division
from __future__ import print_function
from keras_applications import nasnet
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-NASNetMobile = nasnet.NASNetMobile
-NASNetLarge = nasnet.NASNetLarge
-decode_predictions = nasnet.decode_predictions
-preprocess_input = nasnet.preprocess_input
-tf_export('keras.applications.nasnet.NASNetMobile',
- 'keras.applications.NASNetMobile')(NASNetMobile)
-tf_export('keras.applications.nasnet.NASNetLarge',
- 'keras.applications.NASNetLarge')(NASNetLarge)
-tf_export('keras.applications.nasnet.preprocess_input')(preprocess_input)
+@tf_export('keras.applications.nasnet.NASNetMobile',
+ 'keras.applications.NASNetMobile')
+@keras_modules_injection
+def NASNetMobile(*args, **kwargs):
+ return nasnet.NASNetMobile(*args, **kwargs)
+
+
+@tf_export('keras.applications.nasnet.NASNetLarge',
+ 'keras.applications.NASNetLarge')
+@keras_modules_injection
+def NASNetLarge(*args, **kwargs):
+ return nasnet.NASNetLarge(*args, **kwargs)
+
+
+@tf_export('keras.applications.nasnet.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return nasnet.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.nasnet.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return nasnet.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/resnet50.py b/tensorflow/python/keras/applications/resnet50.py
index 4d804a3c44..80d3f9044f 100644
--- a/tensorflow/python/keras/applications/resnet50.py
+++ b/tensorflow/python/keras/applications/resnet50.py
@@ -20,11 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import resnet50
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-ResNet50 = resnet50.ResNet50
-decode_predictions = resnet50.decode_predictions
-preprocess_input = resnet50.preprocess_input
-tf_export('keras.applications.resnet50.ResNet50',
- 'keras.applications.ResNet50')(ResNet50)
+@tf_export('keras.applications.resnet50.ResNet50',
+ 'keras.applications.ResNet50')
+@keras_modules_injection
+def ResNet50(*args, **kwargs):
+ return resnet50.ResNet50(*args, **kwargs)
+
+
+@tf_export('keras.applications.resnet50.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return resnet50.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.resnet50.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return resnet50.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/vgg16.py b/tensorflow/python/keras/applications/vgg16.py
index c420d9b81e..8557d26931 100644
--- a/tensorflow/python/keras/applications/vgg16.py
+++ b/tensorflow/python/keras/applications/vgg16.py
@@ -20,11 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import vgg16
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-VGG16 = vgg16.VGG16
-decode_predictions = vgg16.decode_predictions
-preprocess_input = vgg16.preprocess_input
-tf_export('keras.applications.vgg16.VGG16',
- 'keras.applications.VGG16')(VGG16)
+@tf_export('keras.applications.vgg16.VGG16',
+ 'keras.applications.VGG16')
+@keras_modules_injection
+def VGG16(*args, **kwargs):
+ return vgg16.VGG16(*args, **kwargs)
+
+
+@tf_export('keras.applications.vgg16.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return vgg16.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.vgg16.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return vgg16.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/vgg19.py b/tensorflow/python/keras/applications/vgg19.py
index 73d3d1d1c3..8fc04413a0 100644
--- a/tensorflow/python/keras/applications/vgg19.py
+++ b/tensorflow/python/keras/applications/vgg19.py
@@ -20,11 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import vgg19
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-VGG19 = vgg19.VGG19
-decode_predictions = vgg19.decode_predictions
-preprocess_input = vgg19.preprocess_input
-tf_export('keras.applications.vgg19.VGG19',
- 'keras.applications.VGG19')(VGG19)
+@tf_export('keras.applications.vgg19.VGG19',
+ 'keras.applications.VGG19')
+@keras_modules_injection
+def VGG19(*args, **kwargs):
+ return vgg19.VGG19(*args, **kwargs)
+
+
+@tf_export('keras.applications.vgg19.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return vgg19.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.vgg19.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return vgg19.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/applications/xception.py b/tensorflow/python/keras/applications/xception.py
index 5b221ac8e0..960e6dec69 100644
--- a/tensorflow/python/keras/applications/xception.py
+++ b/tensorflow/python/keras/applications/xception.py
@@ -20,12 +20,25 @@ from __future__ import division
from __future__ import print_function
from keras_applications import xception
+
+from tensorflow.python.keras.applications import keras_modules_injection
from tensorflow.python.util.tf_export import tf_export
-Xception = xception.Xception
-decode_predictions = xception.decode_predictions
-preprocess_input = xception.preprocess_input
-tf_export('keras.applications.xception.Xception',
- 'keras.applications.Xception')(Xception)
-tf_export('keras.applications.xception.preprocess_input')(preprocess_input)
+@tf_export('keras.applications.xception.Xception',
+ 'keras.applications.Xception')
+@keras_modules_injection
+def Xception(*args, **kwargs):
+ return xception.Xception(*args, **kwargs)
+
+
+@tf_export('keras.applications.xception.decode_predictions')
+@keras_modules_injection
+def decode_predictions(*args, **kwargs):
+ return xception.decode_predictions(*args, **kwargs)
+
+
+@tf_export('keras.applications.xception.preprocess_input')
+@keras_modules_injection
+def preprocess_input(*args, **kwargs):
+ return xception.preprocess_input(*args, **kwargs)
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 62433a400b..b52ab7f05c 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -696,10 +696,9 @@ def _get_variables(graph=None):
return variables
-def _initialize_variables(session, variables=None):
+def _initialize_variables(session):
"""Utility to initialize uninitialized variables on the fly."""
- if variables is None:
- variables = _get_variables(ops.get_default_graph())
+ variables = _get_variables(ops.get_default_graph())
candidate_vars = []
for v in variables:
if not getattr(v, '_keras_initialized', False):
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index a63267a5dd..266af56611 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -119,7 +119,7 @@ class BackendUtilsTest(test.TestCase):
self.assertEqual(keras.backend.get_uid('foo'), 1)
def test_learning_phase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
keras.backend.set_learning_phase(1)
self.assertEqual(keras.backend.learning_phase(), 1)
with self.assertRaises(ValueError):
@@ -133,7 +133,7 @@ class BackendUtilsTest(test.TestCase):
sess.run(y, feed_dict={x: np.random.random((2, 3))})
def test_learning_phase_scope(self):
- with self.test_session():
+ with self.cached_session():
initial_learning_phase = keras.backend.learning_phase()
with keras.backend.learning_phase_scope(1) as lp:
self.assertEqual(lp, 1)
@@ -156,7 +156,7 @@ class BackendUtilsTest(test.TestCase):
self.assertEqual(keras.backend.int_shape(x), (None, 4))
def test_in_train_phase(self):
- with self.test_session():
+ with self.cached_session():
y1 = keras.backend.variable(1)
y2 = keras.backend.variable(2)
y = keras.backend.in_train_phase(y1, y2)
@@ -194,7 +194,7 @@ class BackendUtilsTest(test.TestCase):
self.assertEqual(y.op.name[:12], 'StopGradient')
def test_function_tf_feed_symbols(self):
- with self.test_session():
+ with self.cached_session():
# Test feeding a resource variable to `function`.
x1 = keras.backend.placeholder(shape=())
x2 = keras.backend.placeholder(shape=())
@@ -232,7 +232,7 @@ class BackendUtilsTest(test.TestCase):
# keras.backend.function() these do not have control dependency on `outputs`
# so they can run in parallel. Also they should not contribute to output of
# keras.backend.function().
- with self.test_session():
+ with self.cached_session():
x = keras.backend.variable(0.)
y = keras.backend.variable(0.)
x_placeholder = keras.backend.placeholder(shape=())
@@ -253,7 +253,7 @@ class BackendUtilsTest(test.TestCase):
# constructor but we can modify the values in the dictionary. Through
# this feed_dict we can provide additional substitutions besides Keras
# inputs.
- with self.test_session():
+ with self.cached_session():
x = keras.backend.variable(0.)
y = keras.backend.variable(0.)
x_placeholder = keras.backend.placeholder(shape=())
@@ -313,7 +313,7 @@ class BackendUtilsTest(test.TestCase):
self.times_called += 1
self.callback_result = result
- with self.test_session():
+ with self.cached_session():
callback = CallbackStub()
x_placeholder = keras.backend.placeholder(shape=())
y_placeholder = keras.backend.placeholder(shape=())
@@ -335,39 +335,39 @@ class BackendUtilsTest(test.TestCase):
class BackendVariableTest(test.TestCase):
def test_zeros(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.zeros((3, 4))
val = keras.backend.eval(x)
self.assertAllClose(val, np.zeros((3, 4)))
def test_ones(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.ones((3, 4))
val = keras.backend.eval(x)
self.assertAllClose(val, np.ones((3, 4)))
def test_eye(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.eye(4)
val = keras.backend.eval(x)
self.assertAllClose(val, np.eye(4))
def test_zeros_like(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.zeros((3, 4))
y = keras.backend.zeros_like(x)
val = keras.backend.eval(y)
self.assertAllClose(val, np.zeros((3, 4)))
def test_ones_like(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.zeros((3, 4))
y = keras.backend.ones_like(x)
val = keras.backend.eval(y)
self.assertAllClose(val, np.ones((3, 4)))
def test_random_uniform_variable(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.random_uniform_variable((30, 20), low=1, high=2, seed=0)
val = keras.backend.eval(x)
self.assertAllClose(val.mean(), 1.5, atol=1e-1)
@@ -375,7 +375,7 @@ class BackendVariableTest(test.TestCase):
self.assertAllClose(val.min(), 1., atol=1e-1)
def test_random_normal_variable(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.random_normal_variable((30, 20), 1., 0.5,
seed=0)
val = keras.backend.eval(x)
@@ -383,20 +383,20 @@ class BackendVariableTest(test.TestCase):
self.assertAllClose(val.std(), 0.5, atol=1e-1)
def test_count_params(self):
- with self.test_session():
+ with self.cached_session():
x = keras.backend.zeros((4, 5))
val = keras.backend.count_params(x)
self.assertAllClose(val, 20)
def test_constant(self):
- with self.test_session():
+ with self.cached_session():
ref_val = np.random.random((3, 4)).astype('float32')
x = keras.backend.constant(ref_val)
val = keras.backend.eval(x)
self.assertAllClose(val, ref_val)
def test_sparse_variable(self):
- with self.test_session():
+ with self.cached_session():
val = scipy.sparse.eye(10)
x = keras.backend.variable(val)
self.assertTrue(isinstance(x, sparse_tensor.SparseTensor))
@@ -445,7 +445,7 @@ class BackendLinearAlgebraTest(test.TestCase):
(keras.backend.argmax, np.argmax),
]
for keras_op, np_op in ops_to_test:
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5),
keras_kwargs={'axis': 1},
np_kwargs={'axis': 1})
@@ -471,7 +471,7 @@ class BackendLinearAlgebraTest(test.TestCase):
(keras.backend.exp, np.exp),
]
for keras_op, np_op in ops_to_test:
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7))
ops_to_test = [
@@ -479,19 +479,19 @@ class BackendLinearAlgebraTest(test.TestCase):
(keras.backend.log, np.log),
]
for keras_op, np_op in ops_to_test:
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(keras_op, np_op,
input_shape=(4, 7),
negative_values=False)
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(
keras.backend.clip, np.clip,
input_shape=(6, 4),
keras_kwargs={'min_value': 0.1, 'max_value': 2.4},
np_kwargs={'a_min': 0.1, 'a_max': 1.4})
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(
keras.backend.pow, np.power,
input_shape=(6, 4),
@@ -510,14 +510,14 @@ class BackendLinearAlgebraTest(test.TestCase):
(keras.backend.minimum, np.minimum),
]
for keras_op, np_op in ops_to_test:
- with self.test_session():
+ with self.cached_session():
compare_two_inputs_op_to_numpy(keras_op, np_op,
input_shape_a=(4, 7),
input_shape_b=(4, 7))
def test_relu(self):
x = ops.convert_to_tensor([[-4, 0], [2, 7]], 'float32')
- with self.test_session():
+ with self.cached_session():
# standard relu
relu_op = keras.backend.relu(x)
self.assertAllClose(keras.backend.eval(relu_op), [[0, 0], [2, 7]])
@@ -579,7 +579,7 @@ class BackendLinearAlgebraTest(test.TestCase):
class BackendShapeOpsTest(test.TestCase):
def test_reshape(self):
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(keras.backend.reshape, np.reshape,
input_shape=(4, 7),
keras_args=[(2, 14)],
@@ -592,7 +592,7 @@ class BackendShapeOpsTest(test.TestCase):
self.assertEqual(y.get_shape().as_list(), [1, 2, 5])
def test_permute_dimensions(self):
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(keras.backend.permute_dimensions,
np.transpose,
input_shape=(4, 7),
@@ -671,14 +671,14 @@ class BackendShapeOpsTest(test.TestCase):
self.assertEqual(y.get_shape().as_list(), [1, 2, 3])
def test_flatten(self):
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(keras.backend.flatten,
np.reshape,
input_shape=(4, 7, 6),
np_args=[(4 * 7 * 6,)])
def test_batch_flatten(self):
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(keras.backend.batch_flatten,
np.reshape,
input_shape=(4, 7, 6),
@@ -693,7 +693,7 @@ class BackendShapeOpsTest(test.TestCase):
y[:, padding[0]:-padding[1], :] = x
return y
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(keras.backend.temporal_padding,
ref_op,
input_shape=(4, 7, 6),
@@ -716,7 +716,7 @@ class BackendShapeOpsTest(test.TestCase):
y[:, :, padding[0][0]:-padding[0][1], padding[1][0]:-padding[1][1]] = x
return y
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(
keras.backend.spatial_2d_padding,
ref_op,
@@ -759,7 +759,7 @@ class BackendShapeOpsTest(test.TestCase):
padding[2][0]:-padding[2][1]] = x
return y
- with self.test_session():
+ with self.cached_session():
compare_single_input_op_to_numpy(
keras.backend.spatial_3d_padding,
ref_op,
@@ -781,7 +781,7 @@ class BackendShapeOpsTest(test.TestCase):
class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
def test_bias_add(self):
- with self.test_session():
+ with self.cached_session():
keras_op = keras.backend.bias_add
np_op = np.add
compare_two_inputs_op_to_numpy(keras_op, np_op,
@@ -807,7 +807,8 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
keras.backend.bias_add(x, b, data_format='unknown')
def test_bias_add_channels_first(self):
- with self.test_session():
+ with self.cached_session():
+
def keras_op(x, b):
return keras.backend.bias_add(x, b, data_format='channels_first')
@@ -983,7 +984,7 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
strides,
output_shape,
'channels_last')
- with self.test_session():
+ with self.cached_session():
conv_cf = keras.backend.eval(conv_cf)
conv_cl = keras.backend.eval(conv_cl)
@@ -1033,7 +1034,7 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
output_shape,
'channels_last')
- with self.test_session():
+ with self.cached_session():
local_conv = keras.backend.eval(local_conv)
local_conv_dim = keras.backend.eval(local_conv_dim)
@@ -1191,7 +1192,7 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
{'go_backwards': False, 'mask': mask},
{'go_backwards': False, 'mask': mask, 'unroll': True},
]
- with self.test_session():
+ with self.cached_session():
for i, kwargs in enumerate(kwargs_list):
last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs,
initial_states,
@@ -1287,7 +1288,7 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
{'go_backwards': False, 'mask': mask},
{'go_backwards': False, 'mask': mask, 'unroll': True},
]
- with self.test_session():
+ with self.cached_session():
for i, kwargs in enumerate(kwargs_list):
last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs,
initial_states,
@@ -1383,7 +1384,7 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
class TestCTC(test.TestCase):
def test_ctc_decode(self):
- with self.test_session():
+ with self.cached_session():
depth = 6
seq_len_0 = 5
input_prob_matrix_0 = np.asarray(
@@ -1408,8 +1409,8 @@ class TestCTC(test.TestCase):
np.array([seq_len_0], dtype=np.int32))
# batch_size length vector of negative log probabilities
log_prob_truth = np.array([
- 0.584855, # output beam 0
- 0.389139 # output beam 1
+ -3.5821197, # output beam 0
+ -3.777835 # output beam 1
], np.float32)[np.newaxis, :]
decode_truth = [np.array([1, 0]), np.array([0, 1, 0])]
@@ -1432,7 +1433,7 @@ class TestCTC(test.TestCase):
self.assertAllClose(log_prob_truth, log_prob_pred)
def test_ctc_batch_cost(self):
- with self.test_session():
+ with self.cached_session():
label_lens = np.expand_dims(np.asarray([5, 4]), 1)
input_lens = np.expand_dims(np.asarray([5, 5]), 1) # number of timesteps
loss_log_probs = [3.34211, 5.42262]
@@ -1488,13 +1489,13 @@ class TestCTC(test.TestCase):
class TestRandomOps(test.TestCase):
def test_random_binomial(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(123)
x = keras.backend.random_binomial((1000, 1000), p=0.5)
self.assertAllClose(np.mean(keras.backend.eval(x)), 0.5, atol=0.1)
def test_truncated_normal(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(123)
x = keras.backend.truncated_normal((1000, 1000), mean=0.0, stddev=1.0)
y = keras.backend.eval(x)
diff --git a/tensorflow/python/keras/constraints_test.py b/tensorflow/python/keras/constraints_test.py
index 84e2db1033..4f674ea7c5 100644
--- a/tensorflow/python/keras/constraints_test.py
+++ b/tensorflow/python/keras/constraints_test.py
@@ -49,7 +49,7 @@ class KerasConstraintsTest(test.TestCase):
assert fn.__class__ == ref_fn.__class__
def test_max_norm(self):
- with self.test_session():
+ with self.cached_session():
array = get_example_array()
for m in get_test_values():
norm_instance = keras.constraints.max_norm(m)
@@ -69,13 +69,13 @@ class KerasConstraintsTest(test.TestCase):
self.assertAllClose(x_normed_actual, x_normed_target, rtol=1e-05)
def test_non_neg(self):
- with self.test_session():
+ with self.cached_session():
non_neg_instance = keras.constraints.non_neg()
normed = non_neg_instance(keras.backend.variable(get_example_array()))
assert np.all(np.min(keras.backend.eval(normed), axis=1) == 0.)
def test_unit_norm(self):
- with self.test_session():
+ with self.cached_session():
unit_norm_instance = keras.constraints.unit_norm()
normalized = unit_norm_instance(
keras.backend.variable(get_example_array()))
@@ -87,7 +87,7 @@ class KerasConstraintsTest(test.TestCase):
assert np.abs(largest_difference) < 10e-5
def test_min_max_norm(self):
- with self.test_session():
+ with self.cached_session():
array = get_example_array()
for m in get_test_values():
norm_instance = keras.constraints.min_max_norm(min_value=m,
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index aca8eac3a5..b6b05c0311 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -42,7 +42,6 @@ from tensorflow.python.keras.utils.generic_utils import to_snake_case # pylint:
from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import function_utils
@@ -50,6 +49,7 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
+from tensorflow.tools.docs import doc_controls
class CallConvention(enum.Enum):
@@ -79,6 +79,7 @@ class Layer(checkpointable.CheckpointableBase):
Users will just instantiate a layer and then treat it as a callable.
We recommend that descendants of `Layer` implement the following methods:
+
* `__init__()`: Save configuration in member variables
* `build()`: Called once from `__call__`, when we know the shapes of inputs
and `dtype`. Should have the calls to `add_weight()`, and then
@@ -272,6 +273,7 @@ class Layer(checkpointable.CheckpointableBase):
return []
return self._updates
+ @doc_controls.for_subclass_implementers
def add_update(self, updates, inputs=None):
"""Add update op(s), potentially dependent on layer inputs.
@@ -372,6 +374,7 @@ class Layer(checkpointable.CheckpointableBase):
else:
return self._losses
+ @doc_controls.for_subclass_implementers
def add_loss(self, losses, inputs=None):
"""Add loss tensor(s), potentially dependent on layer inputs.
@@ -463,10 +466,12 @@ class Layer(checkpointable.CheckpointableBase):
"""Creates the variables of the layer."""
self.built = True
+ @doc_controls.for_subclass_implementers
def add_variable(self, *args, **kwargs):
"""Alias for `add_weight`."""
return self.add_weight(*args, **kwargs)
+ @doc_controls.for_subclass_implementers
def add_weight(self,
name,
shape,
@@ -477,8 +482,8 @@ class Layer(checkpointable.CheckpointableBase):
constraint=None,
partitioner=None,
use_resource=None,
- synchronization=vs.VariableSynchronization.AUTO,
- aggregation=vs.VariableAggregation.NONE,
+ synchronization=tf_variables.VariableSynchronization.AUTO,
+ aggregation=tf_variables.VariableAggregation.NONE,
**kwargs):
"""Adds a new variable to the layer, or gets an existing one; returns it.
@@ -535,7 +540,7 @@ class Layer(checkpointable.CheckpointableBase):
regularizer = regularizers.get(regularizer)
constraint = constraints.get(constraint)
- if synchronization == vs.VariableSynchronization.ON_READ:
+ if synchronization == tf_variables.VariableSynchronization.ON_READ:
if trainable:
raise ValueError(
'Synchronization value can be set to '
@@ -656,6 +661,7 @@ class Layer(checkpointable.CheckpointableBase):
activity_regularization = self._activity_regularizer(output)
self.add_loss(activity_regularization, inputs=inputs)
+ @doc_controls.for_subclass_implementers
def call(self, inputs, **kwargs): # pylint: disable=unused-argument
"""This is where the layer's logic lives.
@@ -1422,11 +1428,13 @@ class Layer(checkpointable.CheckpointableBase):
'instead.' % self.name)
@property
+ @doc_controls.do_not_doc_inheritable
def inbound_nodes(self):
"""Deprecated, do NOT use! Only for compatibility with external Keras."""
return self._inbound_nodes
@property
+ @doc_controls.do_not_doc_inheritable
def outbound_nodes(self):
"""Deprecated, do NOT use! Only for compatibility with external Keras."""
return self._outbound_nodes
@@ -1897,8 +1905,8 @@ def make_variable(name,
constraint=None,
use_resource=None,
collections=None,
- synchronization=vs.VariableSynchronization.AUTO,
- aggregation=vs.VariableAggregation.NONE,
+ synchronization=tf_variables.VariableSynchronization.AUTO,
+ aggregation=tf_variables.VariableAggregation.NONE,
partitioner=None): # pylint: disable=unused-argument
"""Temporary util to create a variable (relies on `variable_scope.variable`).
@@ -1926,8 +1934,8 @@ def make_variable(name,
then this parameter is ignored and any added variables are also
marked as non-trainable. `trainable` defaults to `True` unless
`synchronization` is set to `ON_READ`.
- caching_device: Passed to `vs.variable`.
- validate_shape: Passed to `vs.variable`.
+ caching_device: Passed to `tf.Variable`.
+ validate_shape: Passed to `tf.Variable`.
constraint: Constraint instance (callable).
use_resource: Whether to use a `ResourceVariable`.
collections: List of graph collections keys. The new variable is added to
@@ -1964,7 +1972,7 @@ def make_variable(name,
if use_resource is None:
use_resource = True
- v = vs.variable(
+ v = tf_variables.Variable(
initial_value=init_val,
name=name,
trainable=trainable,
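A minimal sketch of the constraint enforced by the synchronization check above; the layer subclass is hypothetical, and `tf.VariableSynchronization` is assumed to be the public alias of the enum now referenced from `tf_variables`:

    import tensorflow as tf


    class Demo(tf.keras.layers.Layer):

      def build(self, input_shape):
        # A variable with ON_READ synchronization cannot be trainable,
        # per the check in add_weight above.
        self.v = self.add_weight(
            name='v',
            shape=(),
            synchronization=tf.VariableSynchronization.ON_READ,
            trainable=True)


    try:
      Demo()(tf.zeros((1, 1)))  # build() runs here
    except ValueError as e:
      print(e)  # 'Synchronization value can be set to ...'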
diff --git a/tensorflow/python/keras/engine/sequential.py b/tensorflow/python/keras/engine/sequential.py
index cf6fb44275..9f4019e29c 100644
--- a/tensorflow/python/keras/engine/sequential.py
+++ b/tensorflow/python/keras/engine/sequential.py
@@ -332,6 +332,7 @@ class Sequential(Model):
else:
name = None
build_input_shape = None
+ layer_configs = config
model = cls(name=name)
for layer_config in layer_configs:
layer = layer_module.deserialize(layer_config,
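The added `layer_configs = config` line restores support for the legacy serialization format, in which the config is a bare list of layer configs rather than a dict carrying 'name' and 'layers'. A minimal sketch; the Dense config is illustrative:

    import tensorflow as tf

    # Legacy format: a plain list of layer configs, with no enclosing dict.
    legacy_config = [{'class_name': 'Dense', 'config': {'units': 2}}]
    model = tf.keras.Sequential.from_config(legacy_config)
    assert len(model.layers) == 1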
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 502635c408..85d25411b4 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -1727,6 +1727,13 @@ class Model(Network):
if batch_size is None and steps is None:
batch_size = 32
+ # Turn off prefetching since this is currently not deterministic. Once
+ # b/112498930 is fixed we can turn it back on.
+ # `_prefetch_on_device` is currently a property of only `MirroredStrategy`.
+ if (self._distribution_strategy and
+ hasattr(self._distribution_strategy, '_prefetch_on_device')):
+ self._distribution_strategy._prefetch_on_device = False # pylint: disable=protected-access
+
# Validate and standardize user data.
x, _, _ = self._standardize_user_data(
x, check_steps=True, steps_name='steps', steps=steps)
@@ -1735,8 +1742,12 @@ class Model(Network):
return training_eager.predict_loop(
self, x, batch_size=batch_size, verbose=verbose, steps=steps)
elif self._distribution_strategy:
- return training_distributed.predict_loop(
+ results = training_distributed.predict_loop(
self, x, verbose=verbose, steps=steps)
+ # Turn prefetching back on since we turned it off previously.
+ if hasattr(self._distribution_strategy, '_prefetch_on_device'):
+ self._distribution_strategy._prefetch_on_device = True # pylint: disable=protected-access
+ return results
else:
return training_arrays.predict_loop(
self, x, batch_size=batch_size, verbose=verbose, steps=steps)
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index bf2d231861..bf5c7fd7f8 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -36,7 +36,6 @@ from tensorflow.python.keras.engine.training_utils import weighted_masked_object
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
-from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
@@ -392,7 +391,8 @@ class TrainingTest(test.TestCase):
def test_compile_with_sparse_placeholders(self):
with self.test_session():
input_layer = keras.layers.Input(shape=(10,), sparse=True)
- weights = variable_scope.get_variable(name='weights', shape=(10, 1))
+ weights = variables_lib.Variable(
+ np.ones((10, 1)).astype(np.float32), name='weights')
weights_mult = lambda x: sparse_ops.sparse_tensor_dense_matmul(x, weights)
output_layer = keras.layers.Lambda(weights_mult)(input_layer)
model = keras.Model([input_layer], output_layer)
diff --git a/tensorflow/python/keras/initializers_test.py b/tensorflow/python/keras/initializers_test.py
index 51725e03f2..8ddc9a17bf 100644
--- a/tensorflow/python/keras/initializers_test.py
+++ b/tensorflow/python/keras/initializers_test.py
@@ -40,7 +40,7 @@ class KerasInitializersTest(test.TestCase):
def test_uniform(self):
tensor_shape = (9, 6, 7)
- with self.test_session():
+ with self.cached_session():
self._runner(keras.initializers.RandomUniform(minval=-1,
maxval=1,
seed=124),
@@ -49,14 +49,14 @@ class KerasInitializersTest(test.TestCase):
def test_normal(self):
tensor_shape = (8, 12, 99)
- with self.test_session():
+ with self.cached_session():
self._runner(keras.initializers.RandomNormal(mean=0, stddev=1, seed=153),
tensor_shape,
target_mean=0., target_std=1)
def test_truncated_normal(self):
tensor_shape = (12, 99, 7)
- with self.test_session():
+ with self.cached_session():
self._runner(keras.initializers.TruncatedNormal(mean=0,
stddev=1,
seed=126),
@@ -65,13 +65,13 @@ class KerasInitializersTest(test.TestCase):
def test_constant(self):
tensor_shape = (5, 6, 4)
- with self.test_session():
+ with self.cached_session():
self._runner(keras.initializers.Constant(2), tensor_shape,
target_mean=2, target_max=2, target_min=2)
def test_lecun_uniform(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(1. / fan_in)
self._runner(keras.initializers.lecun_uniform(seed=123), tensor_shape,
@@ -79,7 +79,7 @@ class KerasInitializersTest(test.TestCase):
def test_glorot_uniform(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, fan_out = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / (fan_in + fan_out))
self._runner(keras.initializers.glorot_uniform(seed=123), tensor_shape,
@@ -87,7 +87,7 @@ class KerasInitializersTest(test.TestCase):
def test_he_uniform(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / fan_in)
self._runner(keras.initializers.he_uniform(seed=123), tensor_shape,
@@ -95,7 +95,7 @@ class KerasInitializersTest(test.TestCase):
def test_lecun_normal(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(1. / fan_in)
self._runner(keras.initializers.lecun_normal(seed=123), tensor_shape,
@@ -103,7 +103,7 @@ class KerasInitializersTest(test.TestCase):
def test_glorot_normal(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, fan_out = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / (fan_in + fan_out))
self._runner(keras.initializers.glorot_normal(seed=123), tensor_shape,
@@ -111,7 +111,7 @@ class KerasInitializersTest(test.TestCase):
def test_he_normal(self):
tensor_shape = (5, 6, 4, 2)
- with self.test_session():
+ with self.cached_session():
fan_in, _ = init_ops._compute_fans(tensor_shape)
std = np.sqrt(2. / fan_in)
self._runner(keras.initializers.he_normal(seed=123), tensor_shape,
@@ -119,12 +119,12 @@ class KerasInitializersTest(test.TestCase):
def test_orthogonal(self):
tensor_shape = (20, 20)
- with self.test_session():
+ with self.cached_session():
self._runner(keras.initializers.orthogonal(seed=123), tensor_shape,
target_mean=0.)
def test_identity(self):
- with self.test_session():
+ with self.cached_session():
tensor_shape = (3, 4, 5)
with self.assertRaises(ValueError):
self._runner(keras.initializers.identity(), tensor_shape,
@@ -136,13 +136,13 @@ class KerasInitializersTest(test.TestCase):
def test_zero(self):
tensor_shape = (4, 5)
- with self.test_session():
+ with self.cached_session():
self._runner(keras.initializers.zeros(), tensor_shape,
target_mean=0., target_max=0.)
def test_one(self):
tensor_shape = (4, 5)
- with self.test_session():
+ with self.cached_session():
self._runner(keras.initializers.ones(), tensor_shape,
target_mean=1., target_max=1.)
diff --git a/tensorflow/python/keras/integration_test.py b/tensorflow/python/keras/integration_test.py
index a103b9fbf2..3c0f73b1c3 100644
--- a/tensorflow/python/keras/integration_test.py
+++ b/tensorflow/python/keras/integration_test.py
@@ -35,7 +35,7 @@ class KerasIntegrationTest(test.TestCase):
self.assertTrue(keras.__version__.endswith('-tf'))
def test_vector_classification_sequential(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -60,7 +60,7 @@ class KerasIntegrationTest(test.TestCase):
self.assertGreater(history.history['val_acc'][-1], 0.7)
def test_vector_classification_functional(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -84,7 +84,7 @@ class KerasIntegrationTest(test.TestCase):
self.assertGreater(history.history['val_acc'][-1], 0.7)
def test_temporal_classification_sequential(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -106,7 +106,7 @@ class KerasIntegrationTest(test.TestCase):
self.assertGreater(history.history['val_acc'][-1], 0.7)
def test_temporal_classification_sequential_tf_rnn(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -130,7 +130,7 @@ class KerasIntegrationTest(test.TestCase):
self.assertGreater(history.history['val_acc'][-1], 0.7)
def test_image_classification_sequential(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -164,7 +164,7 @@ class KerasIntegrationTest(test.TestCase):
self.assertGreater(history.history['val_acc'][-1], 0.7)
def test_video_classification_functional(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -194,7 +194,7 @@ class KerasIntegrationTest(test.TestCase):
def test_vector_classification_shared_sequential(self):
# Test that Sequential models that feature internal updates
# and internal losses can be shared.
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -228,7 +228,7 @@ class KerasIntegrationTest(test.TestCase):
def test_vector_classification_shared_model(self):
# Test that functional models that feature internal updates
# and internal losses can be shared.
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -259,14 +259,14 @@ class KerasIntegrationTest(test.TestCase):
self.assertGreater(history.history['val_acc'][-1], 0.7)
def test_embedding_with_clipnorm(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Embedding(input_dim=1, output_dim=1))
model.compile(optimizer=keras.optimizers.SGD(clipnorm=0.1), loss='mse')
model.fit(np.array([[0]]), np.array([[[0.5]]]), epochs=1)
def test_using_tf_layers_in_keras_sequential_model(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
@@ -289,7 +289,7 @@ class KerasIntegrationTest(test.TestCase):
self.assertGreater(history.history['val_acc'][-1], 0.7)
def test_using_tf_layers_in_keras_functional_model(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
(x_train, y_train), _ = testing_utils.get_test_data(
train_samples=100,
diff --git a/tensorflow/python/keras/layers/advanced_activations_test.py b/tensorflow/python/keras/layers/advanced_activations_test.py
index 53c1baa2bb..b020b6e730 100644
--- a/tensorflow/python/keras/layers/advanced_activations_test.py
+++ b/tensorflow/python/keras/layers/advanced_activations_test.py
@@ -26,44 +26,44 @@ from tensorflow.python.platform import test
class AdvancedActivationsTest(test.TestCase):
def test_leaky_relu(self):
- with self.test_session():
+ with self.cached_session():
for alpha in [0., .5, -1.]:
testing_utils.layer_test(keras.layers.LeakyReLU,
kwargs={'alpha': alpha},
input_shape=(2, 3, 4))
def test_prelu(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(keras.layers.PReLU, kwargs={},
input_shape=(2, 3, 4))
def test_prelu_share(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(keras.layers.PReLU,
kwargs={'shared_axes': 1},
input_shape=(2, 3, 4))
def test_elu(self):
- with self.test_session():
+ with self.cached_session():
for alpha in [0., .5, -1.]:
testing_utils.layer_test(keras.layers.ELU,
kwargs={'alpha': alpha},
input_shape=(2, 3, 4))
def test_thresholded_relu(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(keras.layers.ThresholdedReLU,
kwargs={'theta': 0.5},
input_shape=(2, 3, 4))
def test_softmax(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(keras.layers.Softmax,
kwargs={'axis': 1},
input_shape=(2, 3, 4))
def test_relu(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(keras.layers.ReLU,
kwargs={'max_value': 10},
input_shape=(2, 3, 4))
@@ -71,14 +71,14 @@ class AdvancedActivationsTest(test.TestCase):
def test_relu_with_invalid_arg(self):
with self.assertRaisesRegexp(
ValueError, 'max_value of Relu layer cannot be negative value: -10'):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(keras.layers.ReLU,
kwargs={'max_value': -10},
input_shape=(2, 3, 4))
with self.assertRaisesRegexp(
ValueError,
'negative_slope of Relu layer cannot be negative value: -2'):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.ReLU,
kwargs={'negative_slope': -2},
diff --git a/tensorflow/python/keras/layers/convolutional_recurrent_test.py b/tensorflow/python/keras/layers/convolutional_recurrent_test.py
index 4b8f6f2a14..4a75793884 100644
--- a/tensorflow/python/keras/layers/convolutional_recurrent_test.py
+++ b/tensorflow/python/keras/layers/convolutional_recurrent_test.py
@@ -47,7 +47,7 @@ class ConvLSTMTest(test.TestCase):
input_channel)
for return_sequences in [True, False]:
- with self.test_session():
+ with self.cached_session():
# test for return state:
x = keras.Input(batch_shape=inputs.shape)
kwargs = {'data_format': data_format,
@@ -92,7 +92,7 @@ class ConvLSTMTest(test.TestCase):
input_num_row, input_num_col,
input_channel)
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
kwargs = {'data_format': 'channels_last',
'return_sequences': False,
@@ -144,7 +144,7 @@ class ConvLSTMTest(test.TestCase):
input_num_row, input_num_col,
input_channel)
- with self.test_session():
+ with self.cached_session():
kwargs = {'data_format': 'channels_last',
'return_sequences': False,
'kernel_size': (num_row, num_col),
@@ -168,7 +168,7 @@ class ConvLSTMTest(test.TestCase):
def test_conv_lstm_dropout(self):
# check dropout
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.ConvLSTM2D,
kwargs={'data_format': 'channels_last',
@@ -181,7 +181,7 @@ class ConvLSTMTest(test.TestCase):
input_shape=(1, 2, 5, 5, 2))
def test_conv_lstm_cloning(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.ConvLSTM2D(5, 3, input_shape=(None, 5, 5, 3)))
@@ -190,7 +190,7 @@ class ConvLSTMTest(test.TestCase):
weights = model.get_weights()
# Use a new graph to clone the model
- with self.test_session():
+ with self.cached_session():
clone = keras.models.clone_model(model)
clone.set_weights(weights)
diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py
index 49ca68ee9e..1df1d575b1 100644
--- a/tensorflow/python/keras/layers/core_test.py
+++ b/tensorflow/python/keras/layers/core_test.py
@@ -30,16 +30,16 @@ from tensorflow.python.platform import test
class CoreLayersTest(test.TestCase):
def test_masking(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.Masking, kwargs={}, input_shape=(3, 2, 3))
def test_dropout(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.Dropout, kwargs={'rate': 0.5}, input_shape=(3, 2))
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.Dropout,
kwargs={'rate': 0.5,
@@ -47,7 +47,7 @@ class CoreLayersTest(test.TestCase):
input_shape=(3, 2))
# https://github.com/tensorflow/tensorflow/issues/14819
- with self.test_session():
+ with self.cached_session():
dropout = keras.layers.Dropout(0.5)
self.assertEqual(True, dropout.supports_masking)
@@ -210,7 +210,7 @@ class CoreLayersTest(test.TestCase):
keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 4, 5, 2))
def test_dense_regularization(self):
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.Dense(
3,
kernel_regularizer=keras.regularizers.l1(0.01),
@@ -221,7 +221,7 @@ class CoreLayersTest(test.TestCase):
self.assertEqual(3, len(layer.losses))
def test_dense_constraints(self):
- with self.test_session():
+ with self.cached_session():
k_constraint = keras.constraints.max_norm(0.01)
b_constraint = keras.constraints.max_norm(0.01)
layer = keras.layers.Dense(
@@ -231,14 +231,14 @@ class CoreLayersTest(test.TestCase):
self.assertEqual(layer.bias.constraint, b_constraint)
def test_activity_regularization(self):
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.ActivityRegularization(l1=0.1)
layer(keras.backend.variable(np.ones((2, 4))))
self.assertEqual(1, len(layer.losses))
_ = layer.get_config()
def test_lambda_output_shape(self):
- with self.test_session():
+ with self.cached_session():
l = keras.layers.Lambda(lambda x: x + 1, output_shape=(1, 1))
l(keras.backend.variable(np.ones((1, 1))))
self.assertEqual((1, 1), l.get_config()['output_shape'])
@@ -247,13 +247,13 @@ class CoreLayersTest(test.TestCase):
def get_output_shape(input_shape):
return 1 * input_shape
- with self.test_session():
+ with self.cached_session():
l = keras.layers.Lambda(lambda x: x + 1, output_shape=get_output_shape)
l(keras.backend.variable(np.ones((1, 1))))
self.assertEqual('lambda', l.get_config()['output_shape_type'])
def test_lambda_config_serialization(self):
- with self.test_session():
+ with self.cached_session():
# test serialization with output_shape and output_shape_type
layer = keras.layers.Lambda(lambda x: x + 1, output_shape=(1, 1))
layer(keras.backend.variable(np.ones((1, 1))))
diff --git a/tensorflow/python/keras/layers/embeddings_test.py b/tensorflow/python/keras/layers/embeddings_test.py
index fff1c5ef98..cab176ee34 100644
--- a/tensorflow/python/keras/layers/embeddings_test.py
+++ b/tensorflow/python/keras/layers/embeddings_test.py
@@ -68,7 +68,7 @@ class EmbeddingTest(test.TestCase):
expected_output_dtype='float32')
def test_embedding_correctness(self):
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.Embedding(output_dim=2, input_dim=2)
layer.build((None, 2))
matrix = np.array([[1, 1], [2, 2]])
diff --git a/tensorflow/python/keras/layers/local_test.py b/tensorflow/python/keras/layers/local_test.py
index 4781bcae07..8589b32b3c 100644
--- a/tensorflow/python/keras/layers/local_test.py
+++ b/tensorflow/python/keras/layers/local_test.py
@@ -87,7 +87,7 @@ class LocallyConnectedLayersTest(test.TestCase):
keras.layers.LocallyConnected1D,
**kwargs)
else:
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.LocallyConnected1D(**kwargs)
layer.build((num_samples, num_steps, input_dim))
self.assertEqual(len(layer.losses), 2)
@@ -105,7 +105,7 @@ class LocallyConnectedLayersTest(test.TestCase):
'kernel_constraint': k_constraint,
'bias_constraint': b_constraint,
}
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.LocallyConnected1D(**kwargs)
layer.build((num_samples, num_steps, input_dim))
self.assertEqual(layer.kernel.constraint, k_constraint)
@@ -197,7 +197,7 @@ class LocallyConnectedLayersTest(test.TestCase):
keras.layers.LocallyConnected2D,
**kwargs)
else:
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.LocallyConnected2D(**kwargs)
layer.build((num_samples, num_row, num_col, stack_size))
self.assertEqual(len(layer.losses), 2)
@@ -214,7 +214,7 @@ class LocallyConnectedLayersTest(test.TestCase):
'kernel_constraint': k_constraint,
'bias_constraint': b_constraint,
}
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.LocallyConnected2D(**kwargs)
layer.build((num_samples, num_row, num_col, stack_size))
self.assertEqual(layer.kernel.constraint, k_constraint)
diff --git a/tensorflow/python/keras/layers/merge_test.py b/tensorflow/python/keras/layers/merge_test.py
index 39bc98d039..7bcfcaeddb 100644
--- a/tensorflow/python/keras/layers/merge_test.py
+++ b/tensorflow/python/keras/layers/merge_test.py
@@ -46,7 +46,7 @@ class MergeLayersTest(test.TestCase):
self.assertAllClose(out, x1 + x2 + x3, atol=1e-4)
def test_merge_add_masking(self):
- with self.test_session():
+ with self.cached_session():
i1 = keras.layers.Input(shape=(4, 5))
i2 = keras.layers.Input(shape=(4, 5))
m1 = keras.layers.Masking()(i1)
@@ -57,7 +57,7 @@ class MergeLayersTest(test.TestCase):
self.assertListEqual(mask.get_shape().as_list(), [None, 4])
def test_merge_add_dynamic_shape(self):
- with self.test_session():
+ with self.cached_session():
i1 = array_ops.placeholder(shape=(4, None), dtype='float32')
i2 = array_ops.placeholder(shape=(4, 5), dtype='float32')
layer = keras.layers.Add()
@@ -149,7 +149,7 @@ class MergeLayersTest(test.TestCase):
self.assertAllClose(out, np.concatenate([x1, x2], axis=1), atol=1e-4)
def test_merge_concatenate_masking(self):
- with self.test_session():
+ with self.cached_session():
i1 = keras.layers.Input(shape=(4, 5))
i2 = keras.layers.Input(shape=(4, 5))
m1 = keras.layers.Masking()(i1)
diff --git a/tensorflow/python/keras/layers/noise_test.py b/tensorflow/python/keras/layers/noise_test.py
index aa2be62390..cea304680b 100644
--- a/tensorflow/python/keras/layers/noise_test.py
+++ b/tensorflow/python/keras/layers/noise_test.py
@@ -27,14 +27,14 @@ from tensorflow.python.platform import test
class NoiseLayersTest(test.TestCase):
def test_GaussianNoise(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.GaussianNoise,
kwargs={'stddev': 1.},
input_shape=(3, 2, 3))
def test_GaussianDropout(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.GaussianDropout,
kwargs={'rate': 0.5},
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index cd26e04c39..013d572088 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -34,7 +34,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util.tf_export import tf_export
@@ -313,18 +313,18 @@ class BatchNormalization(Layer):
shape=param_shape,
dtype=param_dtype,
initializer=self.moving_mean_initializer,
- synchronization=variable_scope.VariableSynchronization.ON_READ,
+ synchronization=tf_variables.VariableSynchronization.ON_READ,
trainable=False,
- aggregation=variable_scope.VariableAggregation.MEAN)
+ aggregation=tf_variables.VariableAggregation.MEAN)
self.moving_variance = self.add_weight(
name='moving_variance',
shape=param_shape,
dtype=param_dtype,
initializer=self.moving_variance_initializer,
- synchronization=variable_scope.VariableSynchronization.ON_READ,
+ synchronization=tf_variables.VariableSynchronization.ON_READ,
trainable=False,
- aggregation=variable_scope.VariableAggregation.MEAN)
+ aggregation=tf_variables.VariableAggregation.MEAN)
if self.renorm:
# Create variables to maintain the moving mean and standard deviation.
@@ -340,9 +340,9 @@ class BatchNormalization(Layer):
shape=shape,
dtype=param_dtype,
initializer=init_ops.zeros_initializer(),
- synchronization=variable_scope.VariableSynchronization.ON_READ,
+ synchronization=tf_variables.VariableSynchronization.ON_READ,
trainable=False,
- aggregation=variable_scope.VariableAggregation.MEAN)
+ aggregation=tf_variables.VariableAggregation.MEAN)
return var
with distribution_strategy_context.get_distribution_strategy(
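
These hunks only change where the `VariableSynchronization` and `VariableAggregation` enums are imported from (`variables`, aliased `tf_variables`, instead of `variable_scope`); the enum values themselves are unchanged. A hypothetical layer sketch (name and behavior invented) showing the same `add_weight` call pattern the moving statistics above use:

```python
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variables as tf_variables


class MovingMeanLayer(Layer):
  """Illustrative layer: one non-trainable ON_READ/MEAN state variable."""

  def build(self, input_shape):
    self.moving_mean = self.add_weight(
        name='moving_mean',
        shape=(input_shape[-1],),
        initializer=init_ops.zeros_initializer(),
        synchronization=tf_variables.VariableSynchronization.ON_READ,
        trainable=False,
        aggregation=tf_variables.VariableAggregation.MEAN)
    super(MovingMeanLayer, self).build(input_shape)

  def call(self, inputs):
    return inputs
```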
diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py
index a97b4cac46..2844b84799 100644
--- a/tensorflow/python/keras/layers/normalization_test.py
+++ b/tensorflow/python/keras/layers/normalization_test.py
@@ -28,7 +28,7 @@ from tensorflow.python.platform import test
class NormalizationLayersTest(test.TestCase):
def test_basic_batchnorm(self):
- with self.test_session():
+ with self.cached_session():
testing_utils.layer_test(
keras.layers.BatchNormalization,
kwargs={
@@ -54,7 +54,7 @@ class NormalizationLayersTest(test.TestCase):
input_shape=(3, 3))
def test_batchnorm_weights(self):
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.BatchNormalization(scale=False, center=False)
layer.build((None, 3, 4))
self.assertEqual(len(layer.trainable_weights), 0)
@@ -66,7 +66,7 @@ class NormalizationLayersTest(test.TestCase):
self.assertEqual(len(layer.weights), 4)
def test_batchnorm_regularization(self):
- with self.test_session():
+ with self.cached_session():
layer = keras.layers.BatchNormalization(
gamma_regularizer='l1', beta_regularizer='l1')
layer.build((None, 3, 4))
@@ -79,7 +79,7 @@ class NormalizationLayersTest(test.TestCase):
self.assertEqual(layer.beta.constraint, max_norm)
def test_batchnorm_correctness(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
model.add(norm)
@@ -96,7 +96,7 @@ class NormalizationLayersTest(test.TestCase):
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
def test_batchnorm_mixed_precision(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
model.add(norm)
@@ -133,7 +133,7 @@ class NormalizationLayersTest(test.TestCase):
np.testing.assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1)
def test_batchnorm_convnet_channel_last(self):
- with self.test_session():
+ with self.cached_session():
# keras.backend.set_learning_phase(True)
model = keras.models.Sequential()
@@ -155,7 +155,7 @@ class NormalizationLayersTest(test.TestCase):
def test_shared_batchnorm(self):
"""Test that a BN layer can be shared across different data streams.
"""
- with self.test_session():
+ with self.cached_session():
# Test single layer reuse
bn = keras.layers.BatchNormalization()
x1 = keras.layers.Input(shape=(10,))
@@ -187,7 +187,7 @@ class NormalizationLayersTest(test.TestCase):
new_model.train_on_batch(x, x)
def test_that_trainable_disables_updates(self):
- with self.test_session():
+ with self.cached_session():
val_a = np.random.random((10, 4))
val_out = np.random.random((10, 4))
@@ -230,7 +230,7 @@ class NormalizationLayersTest(test.TestCase):
Computes mean and std for current inputs then
applies batch normalization using them.
"""
- with self.test_session():
+ with self.cached_session():
bn_mean = 0.5
bn_std = 10.
val_a = np.expand_dims(np.arange(10.), axis=1)
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 65171acfb6..04b3aecff8 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -73,19 +73,27 @@ class StackedRNNCells(Layer):
'`state_size` attribute. '
'received cells:', cells)
self.cells = cells
+ # reverse_state_order determines whether the state size will be in the
+ # reverse order of the cells' states. Users might want to set this to True
+ # to keep the existing behavior. This is only useful when using
+ # RNN(return_state=True), since the states are returned in the same order
+ # as state_size.
+ self.reverse_state_order = kwargs.pop('reverse_state_order', False)
+ if self.reverse_state_order:
+ logging.warning('reverse_state_order=True in StackedRNNCells will soon '
+ 'be deprecated. Please update the code to work with the '
+ 'natural order of states if you rely on the RNN states, '
+ 'e.g. RNN(return_state=True).')
super(StackedRNNCells, self).__init__(**kwargs)
@property
def state_size(self):
- # States are a flat list
- # in reverse order of the cell stack.
- # This allows to preserve the requirement
- # `stack.state_size[0] == output_dim`.
- # e.g. states of a 2-layer LSTM would be
- # `[h2, c2, h1, c1]`
+ # States are a flat list of the individual cells' state sizes,
+ # e.g. the states of a 2-layer LSTM would be `[h1, c1, h2, c2]`.
# (assuming one LSTM has states [h, c])
+ # In the case of reverse_state_order=True, the state_size will be
+ # [h2, c2, h1, c1].
state_size = []
- for cell in self.cells[::-1]:
+ for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
if _is_multiple_state(cell.state_size):
state_size += list(cell.state_size)
else:
@@ -96,15 +104,16 @@ class StackedRNNCells(Layer):
def output_size(self):
if getattr(self.cells[-1], 'output_size', None) is not None:
return self.cells[-1].output_size
+ elif _is_multiple_state(self.cells[-1].state_size):
+ return self.cells[-1].state_size[0]
else:
- return self.state_size[0]
+ return self.cells[-1].state_size
def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
- # The init state is in reverse order of cell's initial state since the
- # state_size is in reverse order. It is flattened into a list also because
- # the state_size is a flattened list.
+ # The initial state is flattened into a list because state_size is a flat
+ # list.
initial_states = []
- for cell in self.cells[::-1]:
+ for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
get_initial_state_fn = getattr(cell, 'get_initial_state', None)
if get_initial_state_fn:
initial_states.append(get_initial_state_fn(
@@ -118,14 +127,15 @@ class StackedRNNCells(Layer):
def call(self, inputs, states, constants=None, **kwargs):
# Recover per-cell states.
nested_states = []
- for cell in self.cells[::-1]:
+ for cell in self.cells[::-1] if self.reverse_state_order else self.cells:
if _is_multiple_state(cell.state_size):
nested_states.append(states[:len(cell.state_size)])
states = states[len(cell.state_size):]
else:
nested_states.append([states[0]])
states = states[1:]
- nested_states = nested_states[::-1]
+ if self.reverse_state_order:
+ nested_states = nested_states[::-1]
# Call the cells in order and store the returned states.
new_nested_states = []
@@ -139,11 +149,12 @@ class StackedRNNCells(Layer):
new_nested_states.append(states)
# Format the new states as a flat list
- # in reverse cell order.
- states = []
- for cell_states in new_nested_states[::-1]:
- states += cell_states
- return inputs, states
+ new_states = []
+ if self.reverse_state_order:
+ new_nested_states = new_nested_states[::-1]
+ for cell_states in new_nested_states:
+ new_states += cell_states
+ return inputs, new_states
@tf_utils.shape_type_conversion
def build(self, input_shape):
@@ -156,7 +167,9 @@ class StackedRNNCells(Layer):
cell.build([input_shape] + constants_shape)
else:
cell.build(input_shape)
- if _is_multiple_state(cell.state_size):
+ if getattr(cell, 'output_size', None) is not None:
+ output_dim = cell.output_size
+ elif _is_multiple_state(cell.state_size):
output_dim = cell.state_size[0]
else:
output_dim = cell.state_size
@@ -659,6 +672,14 @@ class RNN(Layer):
# note that the .build() method of subclasses MUST define
# self.input_spec and self.state_spec with complete input shapes.
if isinstance(inputs, list):
+ # Get initial_state from the full input spec, since the inputs could be
+ # copied to multiple GPUs.
+ if self._num_constants is None:
+ initial_state = inputs[1:]
+ else:
+ initial_state = inputs[1:-self._num_constants]
+ if len(initial_state) == 0:
+ initial_state = None
inputs = inputs[0]
if initial_state is not None:
pass
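
With this change, `StackedRNNCells` reports states in the natural cell order, and `reverse_state_order=True` opts back into the legacy reversed layout. A sketch of the two orderings, assuming three LSTM cells; the expected tuples mirror the updated tests below:

```python
from tensorflow.python import keras

cells = [keras.layers.LSTMCell(8),
         keras.layers.LSTMCell(16),
         keras.layers.LSTMCell(32)]

# Natural order: per-cell [h, c] sizes from the first cell to the last.
stacked = keras.layers.StackedRNNCells(cells)
print(stacked.state_size)   # expected: (8, 8, 16, 16, 32, 32)
print(stacked.output_size)  # expected: 32 (the last cell's output)

# Legacy order (to be deprecated): the last cell's states come first.
legacy = keras.layers.StackedRNNCells(cells, reverse_state_order=True)
print(legacy.state_size)    # expected: (32, 32, 16, 16, 8, 8)
```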
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index f14b36e7e1..a3861e44d5 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -50,7 +50,7 @@ class RNNTest(test.TestCase):
output = keras.backend.dot(inputs, self.kernel) + prev_output
return output, [output]
- with self.test_session():
+ with self.cached_session():
# Basic test case.
cell = MinimalRNNCell(32, 5)
x = keras.Input((None, 5))
@@ -88,7 +88,7 @@ class RNNTest(test.TestCase):
output -= prev_output_2
return output, [output * 2, output * 3]
- with self.test_session():
+ with self.cached_session():
# Basic test case.
cell = MinimalRNNCell(32, 5)
x = keras.Input((None, 5))
@@ -103,7 +103,8 @@ class RNNTest(test.TestCase):
MinimalRNNCell(16, 8),
MinimalRNNCell(32, 16)]
layer = keras.layers.RNN(cells)
- assert layer.cell.state_size == (32, 32, 16, 16, 8, 8)
+ self.assertEqual(layer.cell.state_size, (8, 8, 16, 16, 32, 32))
+ self.assertEqual(layer.cell.output_size, 32)
y = layer(x)
model = keras.models.Model(x, y)
model.compile(optimizer='rmsprop', loss='mse')
@@ -139,7 +140,7 @@ class RNNTest(test.TestCase):
base_config = super(MinimalRNNCell, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
- with self.test_session():
+ with self.cached_session():
# Test basic case.
x = keras.Input((None, 5))
cell = MinimalRNNCell(32)
@@ -228,7 +229,7 @@ class RNNTest(test.TestCase):
base_config = super(RNNCellWithConstants, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
- with self.test_session():
+ with self.cached_session():
# Test basic case.
x = keras.Input((None, 5))
c = keras.Input((3,))
@@ -243,7 +244,7 @@ class RNNTest(test.TestCase):
np.zeros((6, 32))
)
- with self.test_session():
+ with self.cached_session():
# Test basic case serialization.
x_np = np.random.random((6, 5, 5))
c_np = np.random.random((6, 3))
@@ -259,7 +260,7 @@ class RNNTest(test.TestCase):
y_np_2 = model.predict([x_np, c_np])
self.assertAllClose(y_np, y_np_2, atol=1e-4)
- with self.test_session():
+ with self.cached_session():
# test flat list inputs.
with keras.utils.CustomObjectScope(custom_objects):
layer = keras.layers.RNN.from_config(config.copy())
@@ -269,7 +270,7 @@ class RNNTest(test.TestCase):
y_np_3 = model.predict([x_np, c_np])
self.assertAllClose(y_np, y_np_3, atol=1e-4)
- with self.test_session():
+ with self.cached_session():
# Test stacking.
cells = [keras.layers.recurrent.GRUCell(8),
RNNCellWithConstants(12),
@@ -283,7 +284,7 @@ class RNNTest(test.TestCase):
np.zeros((6, 32))
)
- with self.test_session():
+ with self.cached_session():
# Test GRUCell reset_after property.
x = keras.Input((None, 5))
c = keras.Input((3,))
@@ -297,7 +298,7 @@ class RNNTest(test.TestCase):
np.zeros((6, 32))
)
- with self.test_session():
+ with self.cached_session():
# Test stacked RNN serialization
x_np = np.random.random((6, 5, 5))
c_np = np.random.random((6, 3))
@@ -355,7 +356,7 @@ class RNNTest(test.TestCase):
base_config = super(RNNCellWithConstants, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
- with self.test_session():
+ with self.cached_session():
# Test basic case.
x = keras.Input((None, 5))
c = keras.Input((3,))
@@ -370,7 +371,7 @@ class RNNTest(test.TestCase):
np.zeros((6, 32))
)
- with self.test_session():
+ with self.cached_session():
# Test basic case serialization.
x_np = np.random.random((6, 5, 5))
s_np = np.random.random((6, 32))
@@ -392,7 +393,7 @@ class RNNTest(test.TestCase):
with self.assertRaises(AssertionError):
self.assertAllClose(y_np, y_np_2_different_s, atol=1e-4)
- with self.test_session():
+ with self.cached_session():
# test flat list inputs
with keras.utils.CustomObjectScope(custom_objects):
layer = keras.layers.RNN.from_config(config.copy())
@@ -467,7 +468,7 @@ class RNNTest(test.TestCase):
timesteps = 2
num_samples = 2
- with self.test_session():
+ with self.cached_session():
input1 = keras.Input(batch_shape=(num_samples, timesteps, embedding_dim))
layer = layer_class(units,
return_state=True,
@@ -487,7 +488,7 @@ class RNNTest(test.TestCase):
for cell_class in [keras.layers.SimpleRNNCell,
keras.layers.GRUCell,
keras.layers.LSTMCell]:
- with self.test_session():
+ with self.cached_session():
# Test basic case.
x = keras.Input((None, 5))
cell = cell_class(32)
@@ -534,7 +535,7 @@ class RNNTest(test.TestCase):
keras.layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1)]
layer = keras.layers.RNN(cells)
- with self.test_session():
+ with self.cached_session():
x = keras.Input((None, 5))
y = layer(x)
model = keras.models.Model(x, y)
@@ -551,6 +552,21 @@ class RNNTest(test.TestCase):
layer = keras.layers.RNN(cells, return_state=True, return_sequences=True)
output_shape = layer.compute_output_shape((None, timesteps, embedding_dim))
expected_output_shape = [(None, timesteps, 6),
+ (None, 3),
+ (None, 3),
+ (None, 6),
+ (None, 6)]
+ self.assertEqual(
+ [tuple(o.as_list()) for o in output_shape],
+ expected_output_shape)
+
+ # Test reverse_state_order = True for stacked cell.
+ stacked_cell = keras.layers.StackedRNNCells(
+ cells, reverse_state_order=True)
+ layer = keras.layers.RNN(
+ stacked_cell, return_state=True, return_sequences=True)
+ output_shape = layer.compute_output_shape((None, timesteps, embedding_dim))
+ expected_output_shape = [(None, timesteps, 6),
(None, 6),
(None, 6),
(None, 3),
@@ -561,7 +577,7 @@ class RNNTest(test.TestCase):
def test_checkpointable_dependencies(self):
rnn = keras.layers.SimpleRNN
- with self.test_session():
+ with self.cached_session():
x = np.random.random((2, 2, 2))
y = np.random.random((2, 2))
model = keras.models.Sequential()
@@ -576,7 +592,7 @@ class RNNTest(test.TestCase):
self.assertIn(v, checkpointed_objects)
def test_high_dimension_RNN(self):
- with self.test_session():
+ with self.cached_session():
# Basic test case.
unit_a = 10
unit_b = 20
@@ -626,7 +642,7 @@ class RNNTest(test.TestCase):
batch = 32
time_step = 4
- with self.test_session():
+ with self.cached_session():
# Basic test case.
cell = Minimal2DRNNCell(unit_a, unit_b)
x = keras.Input((None, input_a, input_b))
@@ -642,7 +658,7 @@ class RNNTest(test.TestCase):
], np.zeros((batch, unit_a, unit_b)))
self.assertEqual(model.output_shape, (None, unit_a, unit_b))
- with self.test_session():
+ with self.cached_session():
# Bad init state shape.
bad_shape_a = unit_a * 2
bad_shape_b = unit_b * 2
@@ -655,7 +671,7 @@ class RNNTest(test.TestCase):
layer(x, initial_state=s)
def test_inconsistent_output_state_size(self):
- with self.test_session():
+ with self.cached_session():
batch = 32
time_step = 4
state_size = 5
diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py
index 9b8d5fc5cc..a1933c11b0 100644
--- a/tensorflow/python/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/layers/wrappers.py
@@ -545,11 +545,27 @@ class Bidirectional(Wrapper):
if initial_state is not None and generic_utils.has_arg(
self.layer.call, 'initial_state'):
- forward_state = initial_state[:len(initial_state) // 2]
- backward_state = initial_state[len(initial_state) // 2:]
- y = self.forward_layer.call(inputs, initial_state=forward_state, **kwargs)
- y_rev = self.backward_layer.call(
- inputs, initial_state=backward_state, **kwargs)
+ forward_inputs = [inputs[0]]
+ backward_inputs = [inputs[0]]
+ pivot = len(initial_state) // 2 + 1
+ # add forward initial state
+ forward_state = inputs[1:pivot]
+ forward_inputs += forward_state
+ if self._num_constants is None:
+ # add backward initial state
+ backward_state = inputs[pivot:]
+ backward_inputs += backward_state
+ else:
+ # add backward initial state
+ backward_state = inputs[pivot:-self._num_constants]
+ backward_inputs += backward_state
+ # add constants for forward and backward layers
+ forward_inputs += inputs[-self._num_constants:]
+ backward_inputs += inputs[-self._num_constants:]
+ y = self.forward_layer.call(forward_inputs,
+ initial_state=forward_state, **kwargs)
+ y_rev = self.backward_layer.call(backward_inputs,
+ initial_state=backward_state, **kwargs)
else:
y = self.forward_layer.call(inputs, **kwargs)
y_rev = self.backward_layer.call(inputs, **kwargs)
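
The new `Bidirectional.call` logic keeps initial states (and constants) inside the flat `inputs` list so they are replicated along with the inputs, then splits them at `pivot`. A standalone, illustrative restatement of just the slicing; the helper name is invented for this sketch:

```python
def split_bidirectional_inputs(inputs, initial_state, num_constants=None):
  """Splits [x, fwd_states..., bwd_states..., constants...] per direction."""
  forward_inputs = [inputs[0]]
  backward_inputs = [inputs[0]]
  pivot = len(initial_state) // 2 + 1
  forward_inputs += inputs[1:pivot]                 # forward initial state
  if num_constants is None:
    backward_inputs += inputs[pivot:]               # backward initial state
  else:
    backward_inputs += inputs[pivot:-num_constants]
    forward_inputs += inputs[-num_constants:]       # constants are shared
    backward_inputs += inputs[-num_constants:]
  return forward_inputs, backward_inputs


# With two states per direction and one constant:
fwd, bwd = split_bidirectional_inputs(
    ['x', 'f1', 'f2', 'b1', 'b2', 'c'], ['f1', 'f2', 'b1', 'b2'],
    num_constants=1)
assert fwd == ['x', 'f1', 'f2', 'c']
assert bwd == ['x', 'b1', 'b2', 'c']
```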
diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py
index 0cd774ef0f..965960917c 100644
--- a/tensorflow/python/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/layers/wrappers_test.py
@@ -113,7 +113,7 @@ class TimeDistributedTest(test.TestCase):
keras.layers.TimeDistributed(x)
def test_timedistributed_conv2d(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.TimeDistributed(
@@ -128,7 +128,7 @@ class TimeDistributedTest(test.TestCase):
model.summary()
def test_timedistributed_stacked(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.TimeDistributed(
@@ -144,7 +144,7 @@ class TimeDistributedTest(test.TestCase):
batch_size=10)
def test_regularizers(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(
keras.layers.TimeDistributed(
@@ -155,7 +155,7 @@ class TimeDistributedTest(test.TestCase):
self.assertEqual(len(model.losses), 1)
def test_TimeDistributed_learning_phase(self):
- with self.test_session():
+ with self.cached_session():
# test layers that need learning_phase to be set
np.random.seed(1234)
x = keras.layers.Input(shape=(3, 2))
@@ -166,7 +166,7 @@ class TimeDistributedTest(test.TestCase):
self.assertAllClose(np.mean(y), 0., atol=1e-1, rtol=1e-1)
def test_TimeDistributed_batchnorm(self):
- with self.test_session():
+ with self.cached_session():
# test that wrapped BN updates still work.
model = keras.models.Sequential()
model.add(keras.layers.TimeDistributed(
@@ -202,7 +202,7 @@ class TimeDistributedTest(test.TestCase):
assert len(layer.trainable_weights) == 2
def test_TimeDistributed_with_masked_embedding_and_unspecified_shape(self):
- with self.test_session():
+ with self.cached_session():
# test with unspecified shape and Embeddings with mask_zero
model = keras.models.Sequential()
model.add(keras.layers.TimeDistributed(
@@ -234,7 +234,7 @@ class TimeDistributedTest(test.TestCase):
self.assertIs(mask_outputs[-1], None) # final layer
def test_TimeDistributed_with_masking_layer(self):
- with self.test_session():
+ with self.cached_session():
# test with Masking layer
model = keras.models.Sequential()
model.add(keras.layers.TimeDistributed(keras.layers.Masking(
@@ -266,7 +266,7 @@ class BidirectionalTest(test.TestCase):
dim = 2
timesteps = 2
output_dim = 2
- with self.test_session():
+ with self.cached_session():
for mode in ['sum', 'concat', 'ave', 'mul']:
x = np.random.random((samples, timesteps, dim))
target_dim = 2 * output_dim if mode == 'concat' else output_dim
@@ -310,7 +310,7 @@ class BidirectionalTest(test.TestCase):
dim = 2
timesteps = 2
output_dim = 2
- with self.test_session():
+ with self.cached_session():
x = np.random.random((samples, timesteps, dim))
model = keras.models.Sequential()
model.add(
@@ -331,7 +331,7 @@ class BidirectionalTest(test.TestCase):
output_dim = 2
mode = 'sum'
- with self.test_session():
+ with self.cached_session():
x = np.random.random((samples, timesteps, dim))
target_dim = 2 * output_dim if mode == 'concat' else output_dim
y = np.random.random((samples, target_dim))
@@ -363,7 +363,7 @@ class BidirectionalTest(test.TestCase):
output_dim = 2
mode = 'sum'
- with self.test_session():
+ with self.cached_session():
x = np.random.random((samples, timesteps, dim))
target_dim = 2 * output_dim if mode == 'concat' else output_dim
y = np.random.random((samples, target_dim))
@@ -383,7 +383,7 @@ class BidirectionalTest(test.TestCase):
units = 3
x = [np.random.rand(samples, timesteps, dim)]
- with self.test_session():
+ with self.cached_session():
for merge_mode in ['sum', 'mul', 'ave', 'concat', None]:
if merge_mode == 'sum':
merge_func = lambda y, y_rev: y + y_rev
@@ -447,7 +447,7 @@ class BidirectionalTest(test.TestCase):
merge_mode = 'sum'
x = [np.random.rand(samples, timesteps, dim)]
- with self.test_session():
+ with self.cached_session():
inputs = keras.Input((timesteps, dim))
wrapped = keras.layers.Bidirectional(
rnn(units, dropout=0.2, recurrent_dropout=0.2), merge_mode=merge_mode)
@@ -474,7 +474,7 @@ class BidirectionalTest(test.TestCase):
timesteps = 3
units = 3
- with self.test_session():
+ with self.cached_session():
input1 = keras.layers.Input((timesteps, dim))
layer = keras.layers.Bidirectional(
rnn(units, return_state=True, return_sequences=True))
@@ -498,7 +498,7 @@ class BidirectionalTest(test.TestCase):
def test_Bidirectional_trainable(self):
# test layers that need learning_phase to be set
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3, 2))
layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3))
_ = layer(x)
@@ -509,7 +509,7 @@ class BidirectionalTest(test.TestCase):
assert len(layer.trainable_weights) == 6
def test_Bidirectional_updates(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3, 2))
x_reachable_update = x * x
layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3))
@@ -526,7 +526,7 @@ class BidirectionalTest(test.TestCase):
assert len(layer.get_updates_for(x)) == 2
def test_Bidirectional_losses(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3, 2))
x_reachable_loss = x * x
layer = keras.layers.Bidirectional(
@@ -545,7 +545,7 @@ class BidirectionalTest(test.TestCase):
assert len(layer.get_losses_for(x)) == 2
def test_Bidirectional_with_constants(self):
- with self.test_session():
+ with self.cached_session():
# Test basic case.
x = keras.Input((5, 5))
c = keras.Input((3,))
@@ -586,7 +586,7 @@ class BidirectionalTest(test.TestCase):
self.assertAllClose(y_np, y_np_3, atol=1e-4)
def test_Bidirectional_with_constants_layer_passing_initial_state(self):
- with self.test_session():
+ with self.cached_session():
# Test basic case.
x = keras.Input((5, 5))
c = keras.Input((3,))
diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py
index 3098a6d071..c7015270ac 100644
--- a/tensorflow/python/keras/losses_test.py
+++ b/tensorflow/python/keras/losses_test.py
@@ -63,7 +63,7 @@ class _MSEMAELoss(object):
class KerasLossesTest(test.TestCase):
def test_objective_shapes_3d(self):
- with self.test_session():
+ with self.cached_session():
y_a = keras.backend.variable(np.random.random((5, 6, 7)))
y_b = keras.backend.variable(np.random.random((5, 6, 7)))
for obj in ALL_LOSSES:
@@ -71,7 +71,7 @@ class KerasLossesTest(test.TestCase):
self.assertListEqual(objective_output.get_shape().as_list(), [5, 6])
def test_objective_shapes_2d(self):
- with self.test_session():
+ with self.cached_session():
y_a = keras.backend.variable(np.random.random((6, 7)))
y_b = keras.backend.variable(np.random.random((6, 7)))
for obj in ALL_LOSSES:
@@ -79,7 +79,7 @@ class KerasLossesTest(test.TestCase):
self.assertListEqual(objective_output.get_shape().as_list(), [6,])
def test_cce_one_hot(self):
- with self.test_session():
+ with self.cached_session():
y_a = keras.backend.variable(np.random.randint(0, 7, (5, 6)))
y_b = keras.backend.variable(np.random.random((5, 6, 7)))
objective_output = keras.losses.sparse_categorical_crossentropy(y_a, y_b)
@@ -119,7 +119,7 @@ class KerasLossesTest(test.TestCase):
self.addCleanup(shutil.rmtree, tmpdir)
model_filename = os.path.join(tmpdir, 'custom_loss.h5')
- with self.test_session():
+ with self.cached_session():
with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}):
loss = _MSEMAELoss(0.3)
inputs = keras.layers.Input((2,))
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index 0983d62c59..14cf1ce2af 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -53,11 +53,12 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.tf_export import tf_export
+from tensorflow.tools.docs import doc_controls
def check_is_tensor_or_operation(x, name):
@@ -388,11 +389,12 @@ class Metric(Layer):
return cls(**config)
### For use by subclasses ###
+ @doc_controls.for_subclass_implementers
def add_weight(self,
name,
shape=(),
- aggregation=vs.VariableAggregation.SUM,
- synchronization=vs.VariableSynchronization.ON_READ,
+ aggregation=tf_variables.VariableAggregation.SUM,
+ synchronization=tf_variables.VariableSynchronization.ON_READ,
initializer=None):
"""Adds state variable. Only for use by subclasses."""
return super(Metric, self).add_weight(
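
Besides the import swap, `Metric.add_weight` is now marked `for_subclass_implementers`: state variables default to ON_READ synchronization with SUM aggregation, which suits accumulators under a distribution strategy. A hypothetical subclass sketch (class name and logic invented for illustration):

```python
from tensorflow.python.keras import metrics
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops


class TotalSum(metrics.Metric):
  """Illustrative metric: accumulates the sum of all values seen so far."""

  def __init__(self, name='total_sum'):
    super(TotalSum, self).__init__(name=name)
    # Relies on the SUM/ON_READ defaults introduced in the hunk above.
    self.total = self.add_weight(
        'total', initializer=init_ops.zeros_initializer())

  def update_state(self, values):
    return self.total.assign_add(math_ops.reduce_sum(values))

  def result(self):
    return self.total
```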
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index 58c55136b4..0bc95a3952 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -40,7 +40,7 @@ from tensorflow.python.training.checkpointable import util as checkpointable_uti
class KerasMetricsTest(test.TestCase):
def test_metrics(self):
- with self.test_session():
+ with self.cached_session():
y_a = K.variable(np.random.random((6, 7)))
y_b = K.variable(np.random.random((6, 7)))
for metric in [metrics.binary_accuracy, metrics.categorical_accuracy]:
@@ -48,14 +48,14 @@ class KerasMetricsTest(test.TestCase):
self.assertEqual(K.eval(output).shape, (6,))
def test_sparse_categorical_accuracy(self):
- with self.test_session():
+ with self.cached_session():
metric = metrics.sparse_categorical_accuracy
y_a = K.variable(np.random.randint(0, 7, (6,)))
y_b = K.variable(np.random.random((6, 7)))
self.assertEqual(K.eval(metric(y_a, y_b)).shape, (6,))
def test_sparse_top_k_categorical_accuracy(self):
- with self.test_session():
+ with self.cached_session():
y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
y_true = K.variable(np.array([[1], [0]]))
result = K.eval(
@@ -69,7 +69,7 @@ class KerasMetricsTest(test.TestCase):
self.assertEqual(result, 0.)
def test_top_k_categorical_accuracy(self):
- with self.test_session():
+ with self.cached_session():
y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
y_true = K.variable(np.array([[0, 1, 0], [1, 0, 0]]))
result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=3))
@@ -80,7 +80,7 @@ class KerasMetricsTest(test.TestCase):
self.assertEqual(result, 0.)
def test_stateful_metrics(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1334)
class BinaryTruePositives(layers.Layer):
@@ -266,7 +266,7 @@ class KerasMetricsTest(test.TestCase):
self.assertEqual(np.round(self.evaluate(m.count), decimals=2), 5.6)
def test_mean_graph_with_placeholder(self):
- with context.graph_mode(), self.test_session() as sess:
+ with context.graph_mode(), self.cached_session() as sess:
m = metrics.Mean()
v = array_ops.placeholder(dtypes.float32)
w = array_ops.placeholder(dtypes.float32)
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index e3032acbfd..39b6042597 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -33,6 +33,7 @@ from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import data_structures
+from tensorflow.python.util.tf_export import tf_export
# API entries importable from `keras.models`:
Model = training.Model # pylint: disable=invalid-name
@@ -226,6 +227,7 @@ def _clone_sequential_model(model, input_tensors=None):
return Sequential(layers=[input_layer] + layers, name=model.name)
+@tf_export('keras.models.clone_model')
def clone_model(model, input_tensors=None):
"""Clone any `Model` instance.
@@ -447,6 +449,7 @@ def clone_and_build_model(
elif model.optimizer:
if isinstance(model.optimizer, optimizers.TFOptimizer):
optimizer = model.optimizer
+ K.track_tf_optimizer(optimizer)
else:
optimizer_config = model.optimizer.get_config()
optimizer = model.optimizer.__class__.from_config(optimizer_config)
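
With the `tf_export` decorator, `clone_model` becomes reachable as `tf.keras.models.clone_model`, and the `track_tf_optimizer` call makes the backend aware of reused `TFOptimizer` instances. A hedged usage sketch for the newly exported symbol, assuming a small functional model:

```python
import tensorflow as tf

inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.Model(inputs, outputs)

# clone_model rebuilds the layer graph with fresh layers and fresh weights.
clone = tf.keras.models.clone_model(model)
clone.set_weights(model.get_weights())  # copy the weights over explicitly
```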
diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py
index 5f755f7b5e..1d0f56f3c8 100644
--- a/tensorflow/python/keras/models_test.py
+++ b/tensorflow/python/keras/models_test.py
@@ -18,18 +18,36 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import copy
import os
import numpy as np
from tensorflow.python import keras
+from tensorflow.python.eager import context
from tensorflow.python.framework import test_util
from tensorflow.python.keras import metrics
from tensorflow.python.keras import models
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
from tensorflow.python.training import adam
+class TestModel(keras.Model):
+ """A model subclass."""
+
+ def __init__(self, n_outputs=4, trainable=True):
+ """A test class with one dense layer and number of outputs as a variable."""
+ super(TestModel, self).__init__()
+ self.layer1 = keras.layers.Dense(n_outputs)
+ self.n_outputs = resource_variable_ops.ResourceVariable(
+ n_outputs, trainable=trainable)
+
+ def call(self, x):
+ return self.layer1(x)
+
+
class TestModelCloning(test.TestCase):
def test_clone_sequential_model(self):
@@ -187,6 +205,36 @@ class TestModelBackend(test.TestCase):
keras.backend.set_floatx(floatx)
+class TestModelDeepCopy(test.TestCase):
+
+ def test_deep_copy_eager_mode_trainable(self):
+ with context.eager_mode():
+ x = random_ops.random_normal((32, 4))
+ model = TestModel(trainable=True)
+ model(x) # Initialize Variables.
+ model_copy = copy.deepcopy(model)
+ self.assertEqual(len(model_copy.trainable_variables), 3)
+ model_copy.n_outputs.assign(1200)
+ self.assertFalse(
+ np.allclose(model_copy.n_outputs.numpy(),
+ model.n_outputs.numpy()))
+
+ def test_deep_copy_eager_mode_not_trainable(self):
+ with context.eager_mode():
+ x = random_ops.random_normal((32, 4))
+ model = TestModel(trainable=False)
+ model(x)
+ model_copy = copy.deepcopy(model)
+ self.assertEqual(len(model_copy.trainable_variables), 2)
+
+ weights = model_copy.get_weights()
+ weights = [w * 4 for w in weights]
+ model_copy.set_weights(weights)
+ self.assertFalse(
+ np.allclose(model.get_weights()[0],
+ model_copy.get_weights()[0]))
+
+
class TestCloneAndBuildModel(test.TestCase):
def test_clone_and_build_non_compiled_model(self):
diff --git a/tensorflow/python/keras/preprocessing/__init__.py b/tensorflow/python/keras/preprocessing/__init__.py
index 2f08f88600..0860eed3cf 100644
--- a/tensorflow/python/keras/preprocessing/__init__.py
+++ b/tensorflow/python/keras/preprocessing/__init__.py
@@ -23,6 +23,8 @@ import keras_preprocessing
from tensorflow.python.keras import backend
from tensorflow.python.keras import utils
+# This exists for compatibility with prior versions of keras_preprocessing.
+# TODO(fchollet): remove in the future.
keras_preprocessing.set_keras_submodules(backend=backend, utils=utils)
from tensorflow.python.keras.preprocessing import image
diff --git a/tensorflow/python/keras/preprocessing/image.py b/tensorflow/python/keras/preprocessing/image.py
index ba227385ef..e33993950d 100644
--- a/tensorflow/python/keras/preprocessing/image.py
+++ b/tensorflow/python/keras/preprocessing/image.py
@@ -27,6 +27,9 @@ try:
except ImportError:
pass
+from tensorflow.python.keras import backend
+from tensorflow.python.keras import utils
+from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
random_rotation = image.random_rotation
@@ -38,14 +41,482 @@ random_channel_shift = image.random_channel_shift
apply_brightness_shift = image.apply_brightness_shift
random_brightness = image.random_brightness
apply_affine_transform = image.apply_affine_transform
-array_to_img = image.array_to_img
-img_to_array = image.img_to_array
-save_img = image.save_img
load_img = image.load_img
-ImageDataGenerator = image.ImageDataGenerator
-Iterator = image.Iterator
-NumpyArrayIterator = image.NumpyArrayIterator
-DirectoryIterator = image.DirectoryIterator
+
+
+@tf_export('keras.preprocessing.image.array_to_img')
+def array_to_img(x, data_format=None, scale=True, dtype=None):
+ """Converts a 3D Numpy array to a PIL Image instance.
+
+ Arguments:
+ x: Input Numpy array.
+ data_format: Image data format.
+ either "channels_first" or "channels_last".
+ scale: Whether to rescale image values
+ to be within `[0, 255]`.
+ dtype: Dtype to use.
+
+ Returns:
+ A PIL Image instance.
+
+ Raises:
+ ImportError: if PIL is not available.
+ ValueError: if invalid `x` or `data_format` is passed.
+ """
+
+ if data_format is None:
+ data_format = backend.image_data_format()
+ kwargs = {}
+ if 'dtype' in tf_inspect.getfullargspec(image.array_to_img)[0]:
+ if dtype is None:
+ dtype = backend.floatx()
+ kwargs['dtype'] = dtype
+ return image.array_to_img(x, data_format=data_format, scale=scale, **kwargs)
+
+
+@tf_export('keras.preprocessing.image.img_to_array')
+def img_to_array(img, data_format=None, dtype=None):
+ """Converts a PIL Image instance to a Numpy array.
+
+ Arguments:
+ img: PIL Image instance.
+ data_format: Image data format,
+ either "channels_first" or "channels_last".
+ dtype: Dtype to use for the returned array.
+
+ Returns:
+ A 3D Numpy array.
+
+ Raises:
+ ValueError: if invalid `img` or `data_format` is passed.
+ """
+
+ if data_format is None:
+ data_format = backend.image_data_format()
+ kwargs = {}
+ if 'dtype' in tf_inspect.getfullargspec(image.img_to_array)[0]:
+ if dtype is None:
+ dtype = backend.floatx()
+ kwargs['dtype'] = dtype
+ return image.img_to_array(img, data_format=data_format, **kwargs)
+
+
+@tf_export('keras.preprocessing.image.save_img')
+def save_img(path,
+ x,
+ data_format=None,
+ file_format=None,
+ scale=True,
+ **kwargs):
+ """Saves an image stored as a Numpy array to a path or file object.
+
+ Arguments:
+ path: Path or file object.
+ x: Numpy array.
+ data_format: Image data format,
+ either "channels_first" or "channels_last".
+ file_format: Optional file format override. If omitted, the
+ format to use is determined from the filename extension.
+ If a file object was used instead of a filename, this
+ parameter should always be used.
+ scale: Whether to rescale image values to be within `[0, 255]`.
+ **kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
+ """
+ if data_format is None:
+ data_format = backend.image_data_format()
+ image.save_img(path,
+ x,
+ data_format=data_format,
+ file_format=file_format,
+ scale=scale, **kwargs)
+
+
+@tf_export('keras.preprocessing.image.Iterator')
+class Iterator(image.Iterator, utils.Sequence):
+ pass
+
+
+@tf_export('keras.preprocessing.image.DirectoryIterator')
+class DirectoryIterator(image.DirectoryIterator, Iterator):
+ """Iterator capable of reading images from a directory on disk.
+
+ Arguments:
+ directory: Path to the directory to read images from.
+ Each subdirectory in this directory will be
+ considered to contain images from one class,
+ or alternatively you could specify class subdirectories
+ via the `classes` argument.
+ image_data_generator: Instance of `ImageDataGenerator`
+ to use for random transformations and normalization.
+ target_size: tuple of integers, dimensions to resize input images to.
+ color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`.
+ Color mode to read images.
+ classes: Optional list of strings, names of subdirectories
+ containing images from each class (e.g. `["dogs", "cats"]`).
+ It will be computed automatically if not set.
+ class_mode: Mode for yielding the targets:
+ `"binary"`: binary targets (if there are only two classes),
+ `"categorical"`: categorical targets,
+ `"sparse"`: integer targets,
+ `"input"`: targets are images identical to input images (mainly
+ used to work with autoencoders),
+ `None`: no targets get yielded (only input images are yielded).
+ batch_size: Integer, size of a batch.
+ shuffle: Boolean, whether to shuffle the data between epochs.
+ seed: Random seed for data shuffling.
+ data_format: String, one of `channels_first`, `channels_last`.
+ save_to_dir: Optional directory where to save the pictures
+ being yielded, in a viewable format. This is useful
+ for visualizing the random transformations being
+ applied, for debugging purposes.
+ save_prefix: String prefix to use for saving sample
+ images (if `save_to_dir` is set).
+ save_format: Format to use for saving sample images
+ (if `save_to_dir` is set).
+ subset: Subset of data (`"training"` or `"validation"`) if
+ validation_split is set in ImageDataGenerator.
+ interpolation: Interpolation method used to resample the image if the
+ target size is different from that of the loaded image.
+ Supported methods are "nearest", "bilinear", and "bicubic".
+ If PIL version 1.1.3 or newer is installed, "lanczos" is also
+ supported. If PIL version 3.4.0 or newer is installed, "box" and
+ "hamming" are also supported. By default, "nearest" is used.
+ dtype: Dtype to use for generated arrays.
+ """
+
+ def __init__(self, directory, image_data_generator,
+ target_size=(256, 256),
+ color_mode='rgb',
+ classes=None,
+ class_mode='categorical',
+ batch_size=32,
+ shuffle=True,
+ seed=None,
+ data_format=None,
+ save_to_dir=None,
+ save_prefix='',
+ save_format='png',
+ follow_links=False,
+ subset=None,
+ interpolation='nearest',
+ dtype=None):
+ if data_format is None:
+ data_format = backend.image_data_format()
+ kwargs = {}
+ if 'dtype' in tf_inspect.getfullargspec(
+ image.ImageDataGenerator.__init__)[0]:
+ if dtype is None:
+ dtype = backend.floatx()
+ kwargs['dtype'] = dtype
+ super(DirectoryIterator, self).__init__(
+ directory, image_data_generator,
+ target_size=target_size,
+ color_mode=color_mode,
+ classes=classes,
+ class_mode=class_mode,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ seed=seed,
+ data_format=data_format,
+ save_to_dir=save_to_dir,
+ save_prefix=save_prefix,
+ save_format=save_format,
+ follow_links=follow_links,
+ subset=subset,
+ interpolation=interpolation,
+ **kwargs)
+
+
+@tf_export('keras.preprocessing.image.NumpyArrayIterator')
+class NumpyArrayIterator(image.NumpyArrayIterator, Iterator):
+ """Iterator yielding data from a Numpy array.
+
+ Arguments:
+ x: Numpy array of input data or tuple.
+ If tuple, the second element is either
+ another numpy array or a list of numpy arrays,
+ each of which gets passed
+ through as an output without any modifications.
+ y: Numpy array of targets data.
+ image_data_generator: Instance of `ImageDataGenerator`
+ to use for random transformations and normalization.
+ batch_size: Integer, size of a batch.
+ shuffle: Boolean, whether to shuffle the data between epochs.
+ sample_weight: Numpy array of sample weights.
+ seed: Random seed for data shuffling.
+ data_format: String, one of `channels_first`, `channels_last`.
+ save_to_dir: Optional directory where to save the pictures
+ being yielded, in a viewable format. This is useful
+ for visualizing the random transformations being
+ applied, for debugging purposes.
+ save_prefix: String prefix to use for saving sample
+ images (if `save_to_dir` is set).
+ save_format: Format to use for saving sample images
+ (if `save_to_dir` is set).
+ subset: Subset of data (`"training"` or `"validation"`) if
+ validation_split is set in ImageDataGenerator.
+ dtype: Dtype to use for the generated arrays.
+ """
+
+ def __init__(self, x, y, image_data_generator,
+ batch_size=32,
+ shuffle=False,
+ sample_weight=None,
+ seed=None,
+ data_format=None,
+ save_to_dir=None,
+ save_prefix='',
+ save_format='png',
+ subset=None,
+ dtype=None):
+ if data_format is None:
+ data_format = backend.image_data_format()
+ kwargs = {}
+ if 'dtype' in tf_inspect.getfullargspec(
+ image.NumpyArrayIterator.__init__)[0]:
+ if dtype is None:
+ dtype = backend.floatx()
+ kwargs['dtype'] = dtype
+ super(NumpyArrayIterator, self).__init__(
+ x, y, image_data_generator,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ sample_weight=sample_weight,
+ seed=seed,
+ data_format=data_format,
+ save_to_dir=save_to_dir,
+ save_prefix=save_prefix,
+ save_format=save_format,
+ subset=subset,
+ **kwargs)
+
+
+@tf_export('keras.preprocessing.image.ImageDataGenerator')
+class ImageDataGenerator(image.ImageDataGenerator):
+ """Generate batches of tensor image data with real-time data augmentation.
+
+ The data will be looped over (in batches).
+
+ Arguments:
+ featurewise_center: Boolean.
+ Set input mean to 0 over the dataset, feature-wise.
+ samplewise_center: Boolean. Set each sample mean to 0.
+ featurewise_std_normalization: Boolean.
+ Divide inputs by std of the dataset, feature-wise.
+ samplewise_std_normalization: Boolean. Divide each input by its std.
+ zca_epsilon: epsilon for ZCA whitening. Default is 1e-6.
+ zca_whitening: Boolean. Apply ZCA whitening.
+ rotation_range: Int. Degree range for random rotations.
+ width_shift_range: Float, 1-D array-like or int
+ - float: fraction of total width, if < 1, or pixels if >= 1.
+ - 1-D array-like: random elements from the array.
+ - int: integer number of pixels from interval
+ `(-width_shift_range, +width_shift_range)`
+ - With `width_shift_range=2` possible values
+ are integers `[-1, 0, +1]`,
+ same as with `width_shift_range=[-1, 0, +1]`,
+ while with `width_shift_range=1.0` possible values are floats
+ in the interval [-1.0, +1.0).
+ height_shift_range: Float, 1-D array-like or int
+ - float: fraction of total height, if < 1, or pixels if >= 1.
+ - 1-D array-like: random elements from the array.
+ - int: integer number of pixels from interval
+ `(-height_shift_range, +height_shift_range)`
+ - With `height_shift_range=2` possible values
+ are integers `[-1, 0, +1]`,
+ same as with `height_shift_range=[-1, 0, +1]`,
+ while with `height_shift_range=1.0` possible values are floats
+ in the interval [-1.0, +1.0).
+ brightness_range: Tuple or list of two floats. Range for picking
+ a brightness shift value from.
+ shear_range: Float. Shear Intensity
+ (Shear angle in counter-clockwise direction in degrees)
+ zoom_range: Float or [lower, upper]. Range for random zoom.
+ If a float, `[lower, upper] = [1-zoom_range, 1+zoom_range]`.
+ channel_shift_range: Float. Range for random channel shifts.
+ fill_mode: One of {"constant", "nearest", "reflect" or "wrap"}.
+ Default is 'nearest'.
+ Points outside the boundaries of the input are filled
+ according to the given mode:
+ - 'constant': kkkkkkkk|abcd|kkkkkkkk (cval=k)
+ - 'nearest': aaaaaaaa|abcd|dddddddd
+ - 'reflect': abcddcba|abcd|dcbaabcd
+ - 'wrap': abcdabcd|abcd|abcdabcd
+ cval: Float or Int.
+ Value used for points outside the boundaries
+ when `fill_mode = "constant"`.
+ horizontal_flip: Boolean. Randomly flip inputs horizontally.
+ vertical_flip: Boolean. Randomly flip inputs vertically.
+ rescale: rescaling factor. Defaults to None.
+ If None or 0, no rescaling is applied,
+ otherwise we multiply the data by the value provided
+ (after applying all other transformations).
+ preprocessing_function: function that will be applied on each input.
+ The function will run after the image is resized and augmented.
+ The function should take one argument:
+ one image (Numpy tensor with rank 3),
+ and should output a Numpy tensor with the same shape.
+ data_format: Image data format,
+ either "channels_first" or "channels_last".
+ "channels_last" mode means that the images should have shape
+ `(samples, height, width, channels)`,
+ "channels_first" mode means that the images should have shape
+ `(samples, channels, height, width)`.
+ It defaults to the `image_data_format` value found in your
+ Keras config file at `~/.keras/keras.json`.
+ If you never set it, then it will be "channels_last".
+ validation_split: Float. Fraction of images reserved for validation
+ (strictly between 0 and 1).
+ dtype: Dtype to use for the generated arrays.
+
+ Examples:
+
+ Example of using `.flow(x, y)`:
+
+ ```python
+ (x_train, y_train), (x_test, y_test) = cifar10.load_data()
+ y_train = np_utils.to_categorical(y_train, num_classes)
+ y_test = np_utils.to_categorical(y_test, num_classes)
+ datagen = ImageDataGenerator(
+ featurewise_center=True,
+ featurewise_std_normalization=True,
+ rotation_range=20,
+ width_shift_range=0.2,
+ height_shift_range=0.2,
+ horizontal_flip=True)
+ # compute quantities required for featurewise normalization
+ # (std, mean, and principal components if ZCA whitening is applied)
+ datagen.fit(x_train)
+ # fits the model on batches with real-time data augmentation:
+ model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
+ steps_per_epoch=len(x_train) / 32, epochs=epochs)
+ # here's a more "manual" example
+ for e in range(epochs):
+ print('Epoch', e)
+ batches = 0
+ for x_batch, y_batch in datagen.flow(x_train, y_train, batch_size=32):
+ model.fit(x_batch, y_batch)
+ batches += 1
+ if batches >= len(x_train) / 32:
+ # we need to break the loop by hand because
+ # the generator loops indefinitely
+ break
+ ```
+
+ Example of using `.flow_from_directory(directory)`:
+
+ ```python
+ train_datagen = ImageDataGenerator(
+ rescale=1./255,
+ shear_range=0.2,
+ zoom_range=0.2,
+ horizontal_flip=True)
+ test_datagen = ImageDataGenerator(rescale=1./255)
+ train_generator = train_datagen.flow_from_directory(
+ 'data/train',
+ target_size=(150, 150),
+ batch_size=32,
+ class_mode='binary')
+ validation_generator = test_datagen.flow_from_directory(
+ 'data/validation',
+ target_size=(150, 150),
+ batch_size=32,
+ class_mode='binary')
+ model.fit_generator(
+ train_generator,
+ steps_per_epoch=2000,
+ epochs=50,
+ validation_data=validation_generator,
+ validation_steps=800)
+ ```
+
+ Example of transforming images and masks together.
+
+ ```python
+ # we create two instances with the same arguments
+ data_gen_args = dict(featurewise_center=True,
+ featurewise_std_normalization=True,
+ rotation_range=90,
+ width_shift_range=0.1,
+ height_shift_range=0.1,
+ zoom_range=0.2)
+ image_datagen = ImageDataGenerator(**data_gen_args)
+ mask_datagen = ImageDataGenerator(**data_gen_args)
+ # Provide the same seed and keyword arguments to the fit and flow methods
+ seed = 1
+ image_datagen.fit(images, augment=True, seed=seed)
+ mask_datagen.fit(masks, augment=True, seed=seed)
+ image_generator = image_datagen.flow_from_directory(
+ 'data/images',
+ class_mode=None,
+ seed=seed)
+ mask_generator = mask_datagen.flow_from_directory(
+ 'data/masks',
+ class_mode=None,
+ seed=seed)
+    # combine generators into one which yields images and masks
+ train_generator = zip(image_generator, mask_generator)
+ model.fit_generator(
+ train_generator,
+ steps_per_epoch=2000,
+ epochs=50)
+ ```
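+
+    Example of using a custom `preprocessing_function`, reusing
+    `x_train` and `y_train` from the first example (a minimal sketch;
+    `add_random_noise` is a hypothetical helper, not part of the API):
+
+    ```python
+    import numpy as np
+
+    def add_random_noise(img):
+        # `img` is a rank-3 Numpy tensor; the output keeps its shape.
+        return img + np.random.normal(0., 5., size=img.shape)
+
+    datagen = ImageDataGenerator(preprocessing_function=add_random_noise)
+    generator = datagen.flow(x_train, y_train, batch_size=32)
+    ```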
+ """
+
+ def __init__(self,
+ featurewise_center=False,
+ samplewise_center=False,
+ featurewise_std_normalization=False,
+ samplewise_std_normalization=False,
+ zca_whitening=False,
+ zca_epsilon=1e-6,
+ rotation_range=0,
+ width_shift_range=0.,
+ height_shift_range=0.,
+ brightness_range=None,
+ shear_range=0.,
+ zoom_range=0.,
+ channel_shift_range=0.,
+ fill_mode='nearest',
+ cval=0.,
+ horizontal_flip=False,
+ vertical_flip=False,
+ rescale=None,
+ preprocessing_function=None,
+ data_format=None,
+ validation_split=0.0,
+ dtype=None):
+ if data_format is None:
+ data_format = backend.image_data_format()
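+    # Only forward `dtype` when the underlying keras_preprocessing
+    # ImageDataGenerator accepts it (older releases of the package do not).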
+ kwargs = {}
+ if 'dtype' in tf_inspect.getfullargspec(
+ image.ImageDataGenerator.__init__)[0]:
+ if dtype is None:
+ dtype = backend.floatx()
+ kwargs['dtype'] = dtype
+ super(ImageDataGenerator, self).__init__(
+ featurewise_center=featurewise_center,
+ samplewise_center=samplewise_center,
+ featurewise_std_normalization=featurewise_std_normalization,
+ samplewise_std_normalization=samplewise_std_normalization,
+ zca_whitening=zca_whitening,
+ zca_epsilon=zca_epsilon,
+ rotation_range=rotation_range,
+ width_shift_range=width_shift_range,
+ height_shift_range=height_shift_range,
+ brightness_range=brightness_range,
+ shear_range=shear_range,
+ zoom_range=zoom_range,
+ channel_shift_range=channel_shift_range,
+ fill_mode=fill_mode,
+ cval=cval,
+ horizontal_flip=horizontal_flip,
+ vertical_flip=vertical_flip,
+ rescale=rescale,
+ preprocessing_function=preprocessing_function,
+ data_format=data_format,
+ validation_split=validation_split,
+ **kwargs)
tf_export('keras.preprocessing.image.random_rotation')(random_rotation)
tf_export('keras.preprocessing.image.random_shift')(random_shift)
@@ -59,11 +530,4 @@ tf_export(
tf_export('keras.preprocessing.image.random_brightness')(random_brightness)
tf_export(
'keras.preprocessing.image.apply_affine_transform')(apply_affine_transform)
-tf_export('keras.preprocessing.image.array_to_img')(array_to_img)
-tf_export('keras.preprocessing.image.img_to_array')(img_to_array)
-tf_export('keras.preprocessing.image.save_img')(save_img)
tf_export('keras.preprocessing.image.load_img')(load_img)
-tf_export('keras.preprocessing.image.ImageDataGenerator')(ImageDataGenerator)
-tf_export('keras.preprocessing.image.Iterator')(Iterator)
-tf_export('keras.preprocessing.image.NumpyArrayIterator')(NumpyArrayIterator)
-tf_export('keras.preprocessing.image.DirectoryIterator')(DirectoryIterator)
diff --git a/tensorflow/python/keras/preprocessing/sequence.py b/tensorflow/python/keras/preprocessing/sequence.py
index 116d3108d9..f014668909 100644
--- a/tensorflow/python/keras/preprocessing/sequence.py
+++ b/tensorflow/python/keras/preprocessing/sequence.py
@@ -21,6 +21,7 @@ from __future__ import print_function
from keras_preprocessing import sequence
+from tensorflow.python.keras import utils
from tensorflow.python.util.tf_export import tf_export
pad_sequences = sequence.pad_sequences
@@ -28,11 +29,67 @@ make_sampling_table = sequence.make_sampling_table
skipgrams = sequence.skipgrams
# TODO(fchollet): consider making `_remove_long_seq` public.
_remove_long_seq = sequence._remove_long_seq # pylint: disable=protected-access
-TimeseriesGenerator = sequence.TimeseriesGenerator
+
+
+@tf_export('keras.preprocessing.sequence.TimeseriesGenerator')
+class TimeseriesGenerator(sequence.TimeseriesGenerator, utils.Sequence):
+ """Utility class for generating batches of temporal data.
+ This class takes in a sequence of data-points gathered at
+ equal intervals, along with time series parameters such as
+ stride, length of history, etc., to produce batches for
+ training/validation.
+ # Arguments
+ data: Indexable generator (such as list or Numpy array)
+ containing consecutive data points (timesteps).
+            The data should be at least 2D, and axis 0 is expected
+ to be the time dimension.
+ targets: Targets corresponding to timesteps in `data`.
+            It should have the same length as `data`.
+ length: Length of the output sequences (in number of timesteps).
+ sampling_rate: Period between successive individual timesteps
+ within sequences. For rate `r`, timesteps
+            `data[i - length]`, `data[i - length + r]`, ..., `data[i - r]`
+            are used to create a sample sequence.
+ stride: Period between successive output sequences.
+ For stride `s`, consecutive output samples would
+            be centered around `data[i]`, `data[i+s]`, `data[i+2*s]`, etc.
+            (see the second example below).
+ start_index: Data points earlier than `start_index` will not be used
+ in the output sequences. This is useful to reserve part of the
+ data for test or validation.
+ end_index: Data points later than `end_index` will not be used
+ in the output sequences. This is useful to reserve part of the
+ data for test or validation.
+ shuffle: Whether to shuffle output samples,
+ or instead draw them in chronological order.
+        reverse: Boolean: if `True`, timesteps in each output sample will be
+ in reverse chronological order.
+ batch_size: Number of timeseries samples in each batch
+ (except maybe the last one).
+ # Returns
+ A [Sequence](/utils/#sequence) instance.
+ # Examples
+ ```python
+ from keras.preprocessing.sequence import TimeseriesGenerator
+ import numpy as np
+ data = np.array([[i] for i in range(50)])
+ targets = np.array([[i] for i in range(50)])
+ data_gen = TimeseriesGenerator(data, targets,
+ length=10, sampling_rate=2,
+ batch_size=2)
+ assert len(data_gen) == 20
+ batch_0 = data_gen[0]
+ x, y = batch_0
+ assert np.array_equal(x,
+ np.array([[[0], [2], [4], [6], [8]],
+ [[1], [3], [5], [7], [9]]]))
+ assert np.array_equal(y,
+ np.array([[10], [11]]))
+ ```
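+  A minimal sketch of `stride`, reusing `data` and `targets` from
+  above (the expected values follow from the argument definitions,
+  so treat them as illustrative rather than authoritative):
+  ```python
+  data_gen = TimeseriesGenerator(data, targets,
+                                 length=10, sampling_rate=2,
+                                 stride=3, batch_size=2)
+  x, y = data_gen[0]
+  # Consecutive samples start `stride` timesteps apart,
+  # so the targets are rows 10 and 13.
+  assert np.array_equal(x,
+                        np.array([[[0], [2], [4], [6], [8]],
+                                  [[3], [5], [7], [9], [11]]]))
+  assert np.array_equal(y,
+                        np.array([[10], [13]]))
+  ```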
+ """
+ pass
+
tf_export('keras.preprocessing.sequence.pad_sequences')(pad_sequences)
tf_export(
'keras.preprocessing.sequence.make_sampling_table')(make_sampling_table)
tf_export('keras.preprocessing.sequence.skipgrams')(skipgrams)
-tf_export(
- 'keras.preprocessing.sequence.TimeseriesGenerator')(TimeseriesGenerator)
diff --git a/tensorflow/python/keras/regularizers_test.py b/tensorflow/python/keras/regularizers_test.py
index e2075785d8..bba4ebb287 100644
--- a/tensorflow/python/keras/regularizers_test.py
+++ b/tensorflow/python/keras/regularizers_test.py
@@ -50,7 +50,7 @@ def create_model(kernel_regularizer=None, activity_regularizer=None):
class KerasRegularizersTest(test.TestCase):
def test_kernel_regularization(self):
- with self.test_session():
+ with self.cached_session():
(x_train, y_train), _ = get_data()
for reg in [keras.regularizers.l1(),
keras.regularizers.l2(),
@@ -62,7 +62,7 @@ class KerasRegularizersTest(test.TestCase):
epochs=1, verbose=0)
def test_activity_regularization(self):
- with self.test_session():
+ with self.cached_session():
(x_train, y_train), _ = get_data()
for reg in [keras.regularizers.l1(), keras.regularizers.l2()]:
model = create_model(activity_regularizer=reg)
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
index 77792d14f5..c7e94998b4 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils_test.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
@@ -180,6 +180,23 @@ class TestMultiGPUModel(test.TestCase):
target_tensors=[targets])
parallel_model.fit(epochs=1, steps_per_epoch=3)
+ def test_multi_gpu_with_multi_input_layers(self):
+ gpus = 2
+
+ if not check_if_compatible_devices(gpus=gpus):
+ return
+
+ with self.test_session():
+ inputs = keras.Input((4, 3))
+ init_state = keras.Input((3,))
+ outputs = keras.layers.SimpleRNN(
+ 3, return_sequences=True)(inputs, initial_state=init_state)
+ x = [np.random.randn(2, 4, 3), np.random.randn(2, 3)]
+ y = np.random.randn(2, 4, 3)
+ model = keras.Model([inputs, init_state], outputs)
+ parallel_model = keras.utils.multi_gpu_model(model, gpus=gpus)
+ parallel_model.compile(loss='mean_squared_error', optimizer='adam')
+ parallel_model.train_on_batch(x, y)
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index b9c5f26cb7..7671da11ab 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -601,7 +601,7 @@ tf_py_test(
tf_py_test(
name = "matrix_logarithm_op_test",
- size = "small",
+ size = "medium",
srcs = ["matrix_logarithm_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -1388,6 +1388,8 @@ cuda_py_test(
"//tensorflow/python/eager:context",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
@@ -1440,6 +1442,7 @@ cuda_py_test(
"//tensorflow/python:array_ops_gen",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:cond_v2",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:data_flow_ops_gen",
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 81442d12e9..b0e24e969c 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -559,6 +559,14 @@ class StridedSliceTest(test_util.TensorFlowTestCase):
s = array_ops.strided_slice(x, begin, end, strides)
self.assertAllEqual([3.], self.evaluate(s))
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testEagerMemory(self):
+ with context.eager_mode():
+ inputs = constant_op.constant(
+ [[[1], [2], [3], [4]]], dtype=dtypes.float32)
+ # Tests that slicing an EagerTensor doesn't leak memory
+ inputs[0] # pylint: disable=pointless-statement
+
def testDegenerateSlices(self):
with self.test_session(use_gpu=True):
checker = StridedSliceChecker(self, StridedSliceChecker.REF_TENSOR)
@@ -1145,7 +1153,7 @@ class IdentityTest(test_util.TensorFlowTestCase):
def testEagerIdentity(self):
with context.eager_mode():
- ctx = context.get_default_context()
+ ctx = context.context()
if not ctx.num_gpus():
self.skipTest("No GPUs found")
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index bda6ca5ca9..05f998d0d2 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -18,8 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import time
import numpy as np
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -29,6 +33,8 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -745,6 +751,146 @@ class AssertPositiveTest(test.TestCase):
self.evaluate(out)
+class EnsureShapeTest(test.TestCase):
+
+ # Static shape inference
+ def testStaticShape(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ ensure_shape_op = check_ops.ensure_shape(placeholder, (3, 3, 3))
+ self.assertEqual(ensure_shape_op.get_shape(), (3, 3, 3))
+
+ def testStaticShape_MergesShapes(self):
+ placeholder = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
+ ensure_shape_op = check_ops.ensure_shape(placeholder, (5, 4, None))
+ self.assertEqual(ensure_shape_op.get_shape(), (5, 4, 3))
+
+ def testStaticShape_RaisesErrorWhenRankIncompatible(self):
+ placeholder = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
+ with self.assertRaises(ValueError):
+ check_ops.ensure_shape(placeholder, (2, 3))
+
+ def testStaticShape_RaisesErrorWhenDimIncompatible(self):
+ placeholder = array_ops.placeholder(dtypes.int32, shape=(None, None, 3))
+ with self.assertRaises(ValueError):
+ check_ops.ensure_shape(placeholder, (2, 2, 4))
+
+ def testStaticShape_CanSetUnknownShape(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ derived = placeholder / 3
+ ensure_shape_op = check_ops.ensure_shape(derived, None)
+ self.assertEqual(ensure_shape_op.get_shape(), None)
+
+ # Dynamic shape check
+ def testEnsuresDynamicShape_RaisesError(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ derived = math_ops.divide(placeholder, 3, name="MyDivide")
+ derived = check_ops.ensure_shape(derived, (3, 3, 3))
+ feed_val = [[1], [2]]
+ with self.test_session() as sess:
+ with self.assertRaisesWithPredicateMatch(
+ errors.InvalidArgumentError,
+ r"Shape of tensor MyDivide \[2,1\] is not compatible with "
+ r"expected shape \[3,3,3\]."):
+ sess.run(derived, feed_dict={placeholder: feed_val})
+
+ def testEnsuresDynamicShape_RaisesErrorDimUnknown(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ derived = placeholder / 3
+ derived = check_ops.ensure_shape(derived, (None, None, 3))
+ feed_val = [[1], [2]]
+ with self.test_session() as sess:
+ with self.assertRaisesWithPredicateMatch(
+ errors.InvalidArgumentError,
+ r"Shape of tensor [A-Za-z_]* \[2,1\] is not compatible with "
+ r"expected shape \[\?,\?,3\]."):
+ sess.run(derived, feed_dict={placeholder: feed_val})
+
+ def testEnsuresDynamicShape(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ derived = placeholder / 3
+ derived = check_ops.ensure_shape(derived, (2, 1))
+ feed_val = [[1], [2]]
+ with self.test_session() as sess:
+ sess.run(derived, feed_dict={placeholder: feed_val})
+
+ def testEnsuresDynamicShape_WithUnknownDims(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ derived = placeholder / 3
+ derived = check_ops.ensure_shape(derived, (None, None))
+ feed_val = [[1], [2]]
+ with self.test_session() as sess:
+ sess.run(derived, feed_dict={placeholder: feed_val})
+
+
+class EnsureShapeBenchmark(test.Benchmark):
+
+ def _grappler_all_off_config(self):
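+    # Disable Grappler rewrites and optimizer passes so the benchmark
+    # measures the raw cost of the ops rather than an optimized graph.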
+ config = config_pb2.ConfigProto()
+ off = rewriter_config_pb2.RewriterConfig.OFF
+ config.graph_options.optimizer_options.opt_level = -1
+ config.graph_options.rewrite_options.disable_model_pruning = 1
+ config.graph_options.rewrite_options.constant_folding = off
+ config.graph_options.rewrite_options.layout_optimizer = off
+ config.graph_options.rewrite_options.arithmetic_optimization = off
+ config.graph_options.rewrite_options.dependency_optimization = off
+ return config
+
+ def _run(self, op, feed_dict=None, num_iters=5000, name=None, **kwargs):
+ config = self._grappler_all_off_config()
+ with session.Session(config=config) as sess:
+ deltas = []
+ # Warm up the session
+ for _ in range(5):
+ sess.run(op, feed_dict=feed_dict)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(op, feed_dict=feed_dict)
+ end = time.time()
+ deltas.append(end - start)
+      # Use the median so that outlier iterations do not skew the result.
+      median_time = np.median(deltas)
+      mean_us = median_time * 1e6
+ self.report_benchmark(
+ name=name,
+ wall_time=mean_us,
+ extras=kwargs,
+ )
+
+ def benchmark_const_op(self):
+ # In this case, we expect that the overhead of a `session.run` call
+ # far outweighs the time taken to execute the op...
+ shape = (3, 3, 100)
+ input_op = random_ops.random_normal(shape)
+ self._run(array_ops.identity(input_op), name="SingleConstOp")
+
+ def benchmark_single_ensure_op(self):
+ # In this case, we expect that the overhead of a `session.run` call
+ # far outweighs the time taken to execute the op...
+ shape = (3, 3, 100)
+ input_op = random_ops.random_normal(shape)
+ ensure_shape_op = check_ops.ensure_shape(input_op, shape)
+ self._run(ensure_shape_op, name="SingleEnsureShapeOp")
+
+ def _apply_n_times(self, op, target, n=1000):
+ for _ in range(n):
+ target = op(target)
+ return target
+
+ def benchmark_n_ops(self):
+ shape = (1000,)
+ input_op = random_ops.random_normal(shape)
+ n_ops = self._apply_n_times(array_ops.identity, input_op)
+ self._run(n_ops, name="NIdentityOps_1000")
+
+ def benchmark_n_ensure_ops(self):
+ shape = (1000,)
+ input_op = random_ops.random_normal(shape)
+ n_ensure_ops = self._apply_n_times(
+ lambda x: check_ops.ensure_shape(array_ops.identity(x), shape),
+ input_op)
+ self._run(n_ensure_ops, name="NEnsureShapeAndIdentityOps_1000")
+
+
class AssertRankTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index b9910133d8..0dc3c53bc0 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -20,9 +20,9 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2
@@ -158,7 +158,7 @@ class CondV2Test(test.TestCase):
def true_fn():
- @function.Defun()
+ @function.defun
def fn():
return x * y * 2.0
@@ -172,6 +172,8 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testNestedDefunInCond(self):
+ self.skipTest("b/110550782")
+
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -180,10 +182,10 @@ class CondV2Test(test.TestCase):
def false_fn():
- @function.Defun()
+ @function.defun
def fn():
- @function.Defun()
+ @function.defun
def nested_fn():
return x * y * 2.0
@@ -196,18 +198,20 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testDoubleNestedDefunInCond(self):
+ self.skipTest("b/110550782")
+
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
def true_fn():
- @function.Defun()
+ @function.defun
def fn():
- @function.Defun()
+ @function.defun
def nested_fn():
- @function.Defun()
+ @function.defun
def nested_nested_fn():
return x * y * 2.0
@@ -368,7 +372,7 @@ class CondV2Test(test.TestCase):
pred_outer, true_fn, false_fn, name="outer_cond")
# Compute grads inside a Defun.
- @function.Defun()
+ @function.defun
def nesting_fn():
return gradients_impl.gradients(cond_outer, [x, y])
@@ -426,10 +430,10 @@ class CondV2Test(test.TestCase):
pred_outer, true_fn, false_fn, name="outer_cond")
# Compute grads inside a Defun.
- @function.Defun()
+ @function.defun
def nesting_fn():
- @function.Defun()
+ @function.defun
def inner_nesting_fn():
return gradients_impl.gradients(cond_outer, [x, y])
@@ -464,6 +468,7 @@ class CondV2Test(test.TestCase):
}), [5., 0.])
def testBuildCondAndGradientInsideDefun(self):
+ self.skipTest("b/110550782")
def build_graph():
pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
@@ -472,7 +477,7 @@ class CondV2Test(test.TestCase):
y = constant_op.constant(2.0, name="y")
# Build cond and its gradient inside a Defun.
- @function.Defun()
+ @function.defun
def fn():
def true_fn():
@@ -718,6 +723,7 @@ class CondV2ContainerTest(test.TestCase):
Make sure the containers are set correctly for both variable creation
(tested by variables.Variable) and for stateful ops (tested by FIFOQueue)
"""
+ self.skipTest("b/113048653")
with ops.Graph().as_default() as g:
with self.test_session(graph=g):
@@ -795,6 +801,7 @@ class CondV2ContainerTest(test.TestCase):
class CondV2ColocationGroupAndDeviceTest(test.TestCase):
def testColocateWithBeforeCond(self):
+ self.skipTest("b/112414483")
with ops.Graph().as_default() as g:
with self.test_session(graph=g):
@@ -819,6 +826,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
self.assertEquals(cond_v2.cond_v2(True, fn2, fn2)[0].eval(), 3)
def testColocateWithInAndOutOfCond(self):
+ self.skipTest("b/112414483")
with ops.Graph().as_default() as g:
with self.test_session(graph=g):
@@ -866,6 +874,7 @@ class CondV2ColocationGroupAndDeviceTest(test.TestCase):
self.assertTrue(len(run_metadata.partition_graphs) >= 2)
def testDeviceBeforeCond(self):
+ self.skipTest("b/112166045")
with ops.Graph().as_default() as g:
with self.test_session(graph=g):
def fn():
diff --git a/tensorflow/python/kernel_tests/constant_op_eager_test.py b/tensorflow/python/kernel_tests/constant_op_eager_test.py
index a0d5557b92..cc788219ef 100644
--- a/tensorflow/python/kernel_tests/constant_op_eager_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_eager_test.py
@@ -523,7 +523,7 @@ class OnesLikeTest(test.TestCase):
class FillTest(test.TestCase):
def _compare(self, dims, val, np_ans, use_gpu):
- ctx = context.get_default_context()
+ ctx = context.context()
device = "GPU:0" if (use_gpu and ctx.num_gpus()) else "CPU:0"
with ops.device(device):
tf_ans = array_ops.fill(dims, val, name="fill")
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 5e0447e4ff..eac97af4ed 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -41,6 +41,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import cond_v2 # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
@@ -333,7 +334,7 @@ class ControlFlowTest(test.TestCase):
def testCondBool(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113296297")
values = constant_op.constant(10)
fn1 = lambda: math_ops.add(values, 1)
@@ -384,7 +385,7 @@ class ControlFlowTest(test.TestCase):
def testCondIndexedSlices(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113296180")
with self.test_session():
values = constant_op.constant(10)
@@ -402,7 +403,7 @@ class ControlFlowTest(test.TestCase):
def testCondSparseTensor(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113296161 (SparseTensors)")
with self.test_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -422,7 +423,7 @@ class ControlFlowTest(test.TestCase):
def testCondResource(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
rv = resource_variable_ops.ResourceVariable(True)
@@ -438,7 +439,7 @@ class ControlFlowTest(test.TestCase):
def testCondIndexedSlicesDifferentTypes(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113293074")
with self.test_session():
values = constant_op.constant(10)
@@ -484,14 +485,14 @@ class ControlFlowTest(test.TestCase):
def testCond_1(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/111124878 (don't return tuple)")
self._testCond_1(use_gpu=False)
self._testCond_1(use_gpu=True)
def testCond_2(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
x = constant_op.constant(10)
@@ -503,7 +504,7 @@ class ControlFlowTest(test.TestCase):
def testCond_3(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
x = constant_op.constant(10)
@@ -518,7 +519,7 @@ class ControlFlowTest(test.TestCase):
def testCond_4(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113324949 (ref vars)")
with self.test_session():
v1 = variables.Variable(7)
@@ -541,9 +542,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(7, v3.eval())
def testCond_5(self):
- if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
-
with self.test_session():
alive = constant_op.constant(True, name="alive")
count = constant_op.constant(0, name="count")
@@ -559,7 +557,7 @@ class ControlFlowTest(test.TestCase):
def testCond_6(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
v1 = variables.Variable([7])
@@ -586,7 +584,7 @@ class ControlFlowTest(test.TestCase):
def testCondRef(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
x = gen_state_ops.variable(
@@ -602,7 +600,7 @@ class ControlFlowTest(test.TestCase):
def testCondWithControl(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/79881896")
with self.test_session() as sess:
control_holder = array_ops.placeholder(dtypes.float32, shape=())
@@ -644,7 +642,7 @@ class ControlFlowTest(test.TestCase):
def testCondSwitchIdentity(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/112477618 (Operation returned from cond)")
# Make sure the recv identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
@@ -661,7 +659,7 @@ class ControlFlowTest(test.TestCase):
def testCondRecvIdentity(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/112477618 (Operation returned from cond)")
# Make sure the switch identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
@@ -680,7 +678,7 @@ class ControlFlowTest(test.TestCase):
def testCondGrad_1(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113346829 (gpu failure)")
graph = ops.Graph()
with graph.as_default():
@@ -693,18 +691,8 @@ class ControlFlowTest(test.TestCase):
grad = gradients_impl.gradients(r, [x])[0]
with self.test_session():
self.assertAllEqual(1.0, grad.eval())
- # The gradients computation creates a tensor with zeros by broadcasting a
- # zeros constant to the required shape. Verify that the zero constant
- # feeding into the fill is dominated by a Switch.
- zero = graph.get_operation_by_name("gradients/zeros/Const")
- self.assertEqual(len(zero.control_inputs), 1)
- self.assertEqual(zero.control_inputs[0].type, "Identity")
- self.assertEqual(zero.control_inputs[0].inputs[0].op.type, "Switch")
def testCondGrad_2(self):
- if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
-
with self.test_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
x = constant_op.constant(10.0)
@@ -719,7 +707,7 @@ class ControlFlowTest(test.TestCase):
def testCondGrad_3(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/110550782 (gradient w.r.t external variable)")
with self.test_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
@@ -738,9 +726,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(30.0, r.eval(feed_dict={c: 3}))
def testNestedCond_Simple(self):
- if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
-
with self.test_session():
x = constant_op.constant(0., name="X")
y = control_flow_ops.cond(
@@ -757,7 +742,7 @@ class ControlFlowTest(test.TestCase):
def testCondGrad_Gather(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113327884")
with self.test_session() as sess:
v1 = variables.Variable([1.0, 42.0])
@@ -932,7 +917,7 @@ class ControlFlowTest(test.TestCase):
def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113294340 (enable while_v2)")
v = constant_op.constant(1.0)
@@ -1391,7 +1376,7 @@ class ControlFlowTest(test.TestCase):
def testWhileCondWithControl(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113294377 (unknown shape)")
# Ensure that no control edges by an outer control dependency context are
# added to nodes inside cond/while contexts.
@@ -1408,7 +1393,7 @@ class ControlFlowTest(test.TestCase):
def testWhileCondWithControl_1(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113324949 (ref vars)")
with self.test_session():
v = variable_scope.get_variable(
@@ -1433,7 +1418,7 @@ class ControlFlowTest(test.TestCase):
def testWhileCondExitControl(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113294340 (enable while_v2)")
with self.test_session():
v = variables.Variable(1)
@@ -1459,7 +1444,7 @@ class ControlFlowTest(test.TestCase):
def testCondWhile_1(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
n = ops.convert_to_tensor(0, name="n")
@@ -1472,7 +1457,7 @@ class ControlFlowTest(test.TestCase):
def testCondWhile_2(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
n = ops.convert_to_tensor(0)
@@ -1485,7 +1470,7 @@ class ControlFlowTest(test.TestCase):
def _testCondWhile_3(self, use_gpu):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113294340 (enable while_v2)")
with self.test_session(use_gpu=use_gpu) as sess:
p = array_ops.placeholder(dtypes.bool)
@@ -1514,7 +1499,7 @@ class ControlFlowTest(test.TestCase):
def testWhileCond_1(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113294377 (unknown shape)")
with self.test_session():
i = ops.convert_to_tensor(0, name="i")
@@ -1532,7 +1517,7 @@ class ControlFlowTest(test.TestCase):
def testWhileCond_2(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113294377 (unknown shape)")
with self.test_session():
n = ops.convert_to_tensor(0, name="n")
@@ -1543,7 +1528,7 @@ class ControlFlowTest(test.TestCase):
def testWhileCond_3(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113294377 (unknown shape)")
with self.test_session():
n = ops.convert_to_tensor(0)
@@ -1806,9 +1791,6 @@ class ControlFlowTest(test.TestCase):
self._testWhileGrad_ColocateGradients(colocate=True)
def testWhileGrad_Square(self):
- if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
-
with self.test_session():
v = constant_op.constant(2.0, name="v")
c = lambda v: math_ops.less(v, 100.0)
@@ -1891,7 +1873,7 @@ class ControlFlowTest(test.TestCase):
def _testNestedWhileCondWhileGrad(self, use_gpu):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113294377 (unknown shape)")
with self.test_session(use_gpu=use_gpu):
v = constant_op.constant(1.0)
@@ -1932,7 +1914,7 @@ class ControlFlowTest(test.TestCase):
def testWhileGradInCond(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/110550782 (gradient w.r.t external variable)")
with self.test_session():
n = ops.convert_to_tensor(1.0, name="n")
@@ -1983,7 +1965,7 @@ class ControlFlowTest(test.TestCase):
def testCondGradInNestedWhiles(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113346829 (gpu failure)")
def outer_body(i, x):
_, x = control_flow_ops.while_loop(
@@ -2299,15 +2281,12 @@ class ControlFlowTest(test.TestCase):
def testWhileCondGrad_Simple(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113294377 (unknown shape)")
self._testWhileCondGrad_Simple(use_gpu=False)
self._testWhileCondGrad_Simple(use_gpu=True)
def testWhileCondGrad_UnknownShape(self):
- if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
-
with self.test_session() as sess:
v = array_ops.placeholder(dtypes.float32)
n = ops.convert_to_tensor(100.0, name="n")
@@ -2655,7 +2634,7 @@ class ControlFlowTest(test.TestCase):
def testOneValueCond(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
@@ -2673,7 +2652,7 @@ class ControlFlowTest(test.TestCase):
def testExampleCond(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/111124878 (don't return tuple)")
with self.test_session():
x = ops.convert_to_tensor([-2.0, 2.0], name="x")
@@ -2691,7 +2670,7 @@ class ControlFlowTest(test.TestCase):
def testCase(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/112477618 (Operation returned from cond)")
with self.test_session():
x = constant_op.constant(1)
@@ -2746,7 +2725,7 @@ class ControlFlowTest(test.TestCase):
def testCaseSideEffects(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/112477618 (Operation returned from cond)")
with self.test_session() as sess:
v0 = variables.Variable(-1)
@@ -2784,7 +2763,7 @@ class ControlFlowTest(test.TestCase):
def testOneOpCond(self):
if control_flow_ops._ENABLE_COND_V2:
- return unittest.skip("disabled when using cond_v2")
+ return unittest.skip("b/113324949 (ref vars)")
with self.test_session():
v = variables.Variable(0)
diff --git a/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py b/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py
index e1920eb568..41ae0b456f 100644
--- a/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py
+++ b/tensorflow/python/kernel_tests/ctc_decoder_ops_test.py
@@ -188,11 +188,11 @@ class CTCGreedyDecoderTest(test.TestCase):
],
dtype=np.float32)
# Add arbitrary offset - this is fine
- input_log_prob_matrix_0 = np.log(input_prob_matrix_0) + 2.0
+ input_prob_matrix_0 = input_prob_matrix_0 + 2.0
# len max_time_steps array of batch_size x depth matrices
inputs = ([
- input_log_prob_matrix_0[t, :][np.newaxis, :] for t in range(seq_len_0)
+ input_prob_matrix_0[t, :][np.newaxis, :] for t in range(seq_len_0)
] # Pad to max_time_steps = 8
+ 2 * [np.zeros(
(1, depth), dtype=np.float32)])
@@ -200,11 +200,11 @@ class CTCGreedyDecoderTest(test.TestCase):
# batch_size length vector of sequence_lengths
seq_lens = np.array([seq_len_0], dtype=np.int32)
- # batch_size length vector of negative log probabilities
+ # batch_size length vector of log probabilities
log_prob_truth = np.array(
[
- 0.584855, # output beam 0
- 0.389139 # output beam 1
+ -5.811451, # output beam 0
+ -6.63339 # output beam 1
],
np.float32)[np.newaxis, :]
@@ -215,11 +215,11 @@ class CTCGreedyDecoderTest(test.TestCase):
[[0, 0], [0, 1]], dtype=np.int64), np.array(
[1, 0], dtype=np.int64), np.array(
[1, 2], dtype=np.int64)),
- # beam 1, batch 0, three outputs decoded
+ # beam 1, batch 0, one output decoded
(np.array(
- [[0, 0], [0, 1], [0, 2]], dtype=np.int64), np.array(
- [0, 1, 0], dtype=np.int64), np.array(
- [1, 3], dtype=np.int64)),
+ [[0, 0]], dtype=np.int64), np.array(
+ [1], dtype=np.int64), np.array(
+ [1, 1], dtype=np.int64)),
]
# Test correct decoding.
diff --git a/tensorflow/python/kernel_tests/distributions/categorical_test.py b/tensorflow/python/kernel_tests/distributions/categorical_test.py
index d8939433ce..c6bb06eab3 100644
--- a/tensorflow/python/kernel_tests/distributions/categorical_test.py
+++ b/tensorflow/python/kernel_tests/distributions/categorical_test.py
@@ -47,7 +47,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
def testP(self):
p = [0.2, 0.8]
dist = categorical.Categorical(probs=p)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(p, dist.probs.eval())
self.assertAllEqual([2], dist.logits.get_shape())
@@ -55,14 +55,14 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
p = np.array([0.2, 0.8], dtype=np.float32)
logits = np.log(p) - 50.
dist = categorical.Categorical(logits=logits)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([2], dist.probs.get_shape())
self.assertAllEqual([2], dist.logits.get_shape())
self.assertAllClose(dist.probs.eval(), p)
self.assertAllClose(dist.logits.eval(), logits)
def testShapes(self):
- with self.test_session():
+ with self.cached_session():
for batch_shape in ([], [1], [2, 3, 4]):
dist = make_categorical(batch_shape, 10)
self.assertAllEqual(batch_shape, dist.batch_shape)
@@ -108,7 +108,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
self.assertEqual(dist.dtype, dist.sample(5).dtype)
def testUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
logits = array_ops.placeholder(dtype=dtypes.float32)
dist = categorical.Categorical(logits)
sample = dist.sample()
@@ -124,13 +124,13 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
def testPMFWithBatch(self):
histograms = [[0.2, 0.8], [0.6, 0.4]]
dist = categorical.Categorical(math_ops.log(histograms) - 50.)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(dist.prob([0, 1]).eval(), [0.2, 0.4])
def testPMFNoBatch(self):
histograms = [0.2, 0.8]
dist = categorical.Categorical(math_ops.log(histograms) - 50.)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(dist.prob(0).eval(), 0.2)
def testCDFWithDynamicEventShapeKnownNdims(self):
@@ -162,7 +162,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
event: event_feed_two
}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual_cdf_one = sess.run(cdf_op, feed_dict=feed_dict_one)
actual_cdf_two = sess.run(cdf_op, feed_dict=feed_dict_two)
@@ -192,7 +192,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
dist = categorical.Categorical(probs=histograms)
cdf_op = dist.cdf(event)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(cdf_op.eval(), expected_cdf)
def testCDFNoBatch(self):
@@ -202,7 +202,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
dist = categorical.Categorical(probs=histogram)
cdf_op = dist.cdf(event)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(cdf_op.eval(), expected_cdf)
def testCDFBroadcasting(self):
@@ -228,7 +228,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
expected_cdf_result[2, 0] = 0.3
expected_cdf_result[2, 1] = 0.75
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(dist.cdf(devent).eval(), expected_cdf_result)
def testBroadcastWithBatchParamsAndBiggerEvent(self):
@@ -286,7 +286,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
"norm_log_cdf": norm.log_cdf(real_event_tf),
}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_result = sess.run(to_run)
self.assertAllEqual(run_result["cat_prob"].shape,
@@ -301,28 +301,28 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
def testLogPMF(self):
logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
dist = categorical.Categorical(logits)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(dist.log_prob([0, 1]).eval(), np.log([0.2, 0.4]))
self.assertAllClose(dist.log_prob([0.0, 1.0]).eval(), np.log([0.2, 0.4]))
def testEntropyNoBatch(self):
logits = np.log([0.2, 0.8]) - 50.
dist = categorical.Categorical(logits)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(dist.entropy().eval(),
-(0.2 * np.log(0.2) + 0.8 * np.log(0.8)))
def testEntropyWithBatch(self):
logits = np.log([[0.2, 0.8], [0.6, 0.4]]) - 50.
dist = categorical.Categorical(logits)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(dist.entropy().eval(), [
-(0.2 * np.log(0.2) + 0.8 * np.log(0.8)),
-(0.6 * np.log(0.6) + 0.4 * np.log(0.4))
])
def testEntropyGradient(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
logits = constant_op.constant([[1., 2., 3.], [2., 5., 1.]])
probabilities = nn_ops.softmax(logits)
@@ -348,7 +348,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
res["categorical_entropy_g"])
def testSample(self):
- with self.test_session():
+ with self.cached_session():
histograms = [[[0.2, 0.8], [0.4, 0.6]]]
dist = categorical.Categorical(math_ops.log(histograms) - 50.)
n = 10000
@@ -366,7 +366,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
sample_values == 1, axis=0), atol=1e-2)
def testSampleWithSampleShape(self):
- with self.test_session():
+ with self.cached_session():
histograms = [[[0.2, 0.8], [0.4, 0.6]]]
dist = categorical.Categorical(math_ops.log(histograms) - 50.)
samples = dist.sample((100, 100), seed=123)
@@ -387,7 +387,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
self.assertIsNone(grad_p)
def testLogPMFBroadcasting(self):
- with self.test_session():
+ with self.cached_session():
# 1 x 2 x 2
histograms = [[[0.2, 0.8], [0.4, 0.6]]]
dist = categorical.Categorical(math_ops.log(histograms) - 50.)
@@ -415,7 +415,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
prob.eval())
def testLogPMFShape(self):
- with self.test_session():
+ with self.cached_session():
# shape [1, 2, 2]
histograms = [[[0.2, 0.8], [0.4, 0.6]]]
dist = categorical.Categorical(math_ops.log(histograms))
@@ -441,7 +441,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual([2, 2, 2], log_prob.get_shape())
def testMode(self):
- with self.test_session():
+ with self.cached_session():
histograms = [[[0.2, 0.8], [0.6, 0.4]]]
dist = categorical.Categorical(math_ops.log(histograms) - 50.)
self.assertAllEqual(dist.mode().eval(), [[1, 0]])
@@ -452,7 +452,7 @@ class CategoricalTest(test.TestCase, parameterized.TestCase):
exp_logits = np.exp(logits)
return exp_logits / exp_logits.sum(axis=-1, keepdims=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for categories in [2, 4]:
for batch_size in [1, 10]:
a_logits = np.random.randn(batch_size, categories)
diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
index 1b9edcc85a..d558ca09cc 100644
--- a/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_multinomial_test.py
@@ -37,7 +37,7 @@ class DirichletMultinomialTest(test.TestCase):
self._rng = np.random.RandomState(42)
def testSimpleShapes(self):
- with self.test_session():
+ with self.cached_session():
alpha = np.random.rand(3)
dist = ds.DirichletMultinomial(1., alpha)
self.assertEqual(3, dist.event_shape_tensor().eval())
@@ -46,7 +46,7 @@ class DirichletMultinomialTest(test.TestCase):
self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)
def testComplexShapes(self):
- with self.test_session():
+ with self.cached_session():
alpha = np.random.rand(3, 2, 2)
n = [[3., 2], [4, 5], [6, 7]]
dist = ds.DirichletMultinomial(n, alpha)
@@ -58,14 +58,14 @@ class DirichletMultinomialTest(test.TestCase):
def testNproperty(self):
alpha = [[1., 2, 3]]
n = [[5.]]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(n, alpha)
self.assertEqual([1, 1], dist.total_count.get_shape())
self.assertAllClose(n, dist.total_count.eval())
def testAlphaProperty(self):
alpha = [[1., 2, 3]]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(1, alpha)
self.assertEqual([1, 3], dist.concentration.get_shape())
self.assertAllClose(alpha, dist.concentration.eval())
@@ -73,7 +73,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfNandCountsAgree(self):
alpha = [[1., 2, 3]]
n = [[5.]]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(n, alpha, validate_args=True)
dist.prob([2., 3, 0]).eval()
dist.prob([3., 0, 2]).eval()
@@ -86,7 +86,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfNonIntegerCounts(self):
alpha = [[1., 2, 3]]
n = [[5.]]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(n, alpha, validate_args=True)
dist.prob([2., 3, 0]).eval()
dist.prob([3., 0, 2]).eval()
@@ -104,7 +104,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfBothZeroBatches(self):
# The probabilities of one vote falling into class k is the mean for class
# k.
- with self.test_session():
+ with self.cached_session():
# Both zero-batches. No broadcast
alpha = [1., 2]
counts = [1., 0]
@@ -116,7 +116,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfBothZeroBatchesNontrivialN(self):
# The probabilities of one vote falling into class k is the mean for class
# k.
- with self.test_session():
+ with self.cached_session():
# Both zero-batches. No broadcast
alpha = [1., 2]
counts = [3., 2]
@@ -128,7 +128,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfBothZeroBatchesMultidimensionalN(self):
# The probabilities of one vote falling into class k is the mean for class
# k.
- with self.test_session():
+ with self.cached_session():
alpha = [1., 2]
counts = [3., 2]
n = np.full([4, 3], 5., dtype=np.float32)
@@ -140,7 +140,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfAlphaStretchedInBroadcastWhenSameRank(self):
# The probabilities of one vote falling into class k is the mean for class
# k.
- with self.test_session():
+ with self.cached_session():
alpha = [[1., 2]]
counts = [[1., 0], [0., 1]]
dist = ds.DirichletMultinomial([1.], alpha)
@@ -151,7 +151,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfAlphaStretchedInBroadcastWhenLowerRank(self):
# The probabilities of one vote falling into class k is the mean for class
# k.
- with self.test_session():
+ with self.cached_session():
alpha = [1., 2]
counts = [[1., 0], [0., 1]]
pmf = ds.DirichletMultinomial(1., alpha).prob(counts)
@@ -161,7 +161,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfCountsStretchedInBroadcastWhenSameRank(self):
# The probabilities of one vote falling into class k is the mean for class
# k.
- with self.test_session():
+ with self.cached_session():
alpha = [[1., 2], [2., 3]]
counts = [[1., 0]]
pmf = ds.DirichletMultinomial([1., 1.], alpha).prob(counts)
@@ -171,7 +171,7 @@ class DirichletMultinomialTest(test.TestCase):
def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
# The probabilities of one vote falling into class k is the mean for class
# k.
- with self.test_session():
+ with self.cached_session():
alpha = [[1., 2], [2., 3]]
counts = [1., 0]
pmf = ds.DirichletMultinomial(1., alpha).prob(counts)
@@ -182,7 +182,7 @@ class DirichletMultinomialTest(test.TestCase):
# The probabilities of one vote falling into class k is the mean for class
# k.
alpha = [1., 2, 3]
- with self.test_session():
+ with self.cached_session():
for class_num in range(3):
counts = np.zeros([3], dtype=np.float32)
counts[class_num] = 1
@@ -199,7 +199,7 @@ class DirichletMultinomialTest(test.TestCase):
# DirichletMultinomial(2, alpha) is twice as much as the probability of one
# vote falling into class k for DirichletMultinomial(1, alpha)
alpha = [1., 2, 3]
- with self.test_session():
+ with self.cached_session():
for class_num in range(3):
counts_one = np.zeros([3], dtype=np.float32)
counts_one[class_num] = 1.
@@ -223,7 +223,7 @@ class DirichletMultinomialTest(test.TestCase):
# Ideally we'd be able to test broadcasting but, the multinomial sampler
# doesn't support different total counts.
n = np.float32(5)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# batch_shape=[2], event_shape=[3]
dist = ds.DirichletMultinomial(n, alpha)
x = dist.sample(int(250e3), seed=1)
@@ -281,7 +281,7 @@ class DirichletMultinomialTest(test.TestCase):
variance_entry(alpha[1], alpha_0)
]])
- with self.test_session():
+ with self.cached_session():
for n in ns:
# n is shape [] and alpha is shape [2].
dist = ds.DirichletMultinomial(n, alpha)
@@ -319,7 +319,7 @@ class DirichletMultinomialTest(test.TestCase):
]]],
dtype=np.float32)
- with self.test_session():
+ with self.cached_session():
# ns is shape [4, 1], and alpha is shape [4, 3].
dist = ds.DirichletMultinomial(ns, alpha)
covariance = dist.covariance()
@@ -336,7 +336,7 @@ class DirichletMultinomialTest(test.TestCase):
ns = np.random.randint(low=1, high=11, size=[3, 5, 1]).astype(np.float32)
ns2 = np.random.randint(low=1, high=11, size=[6, 1, 1]).astype(np.float32)
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(ns, alpha)
dist2 = ds.DirichletMultinomial(ns2, alpha2)
@@ -350,7 +350,7 @@ class DirichletMultinomialTest(test.TestCase):
# probability 1.
alpha = [5, 0.5]
counts = [0., 0]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(0., alpha)
pmf = dist.prob(counts)
self.assertAllClose(1.0, pmf.eval())
@@ -365,7 +365,7 @@ class DirichletMultinomialTest(test.TestCase):
# One (three sided) coin flip. Prob[coin 3] = 0.8.
# Note that since it was one flip, value of tau didn't matter.
counts = [0., 0, 1]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(1., alpha)
pmf = dist.prob(counts)
self.assertAllClose(0.8, pmf.eval(), atol=1e-4)
@@ -373,7 +373,7 @@ class DirichletMultinomialTest(test.TestCase):
# Two (three sided) coin flips. Prob[coin 3] = 0.8.
counts = [0., 0, 2]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(2., alpha)
pmf = dist.prob(counts)
self.assertAllClose(0.8**2, pmf.eval(), atol=1e-2)
@@ -381,7 +381,7 @@ class DirichletMultinomialTest(test.TestCase):
# Three (three sided) coin flips.
counts = [1., 0, 2]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(3., alpha)
pmf = dist.prob(counts)
self.assertAllClose(3 * 0.1 * 0.8 * 0.8, pmf.eval(), atol=1e-2)
@@ -396,7 +396,7 @@ class DirichletMultinomialTest(test.TestCase):
# If there is only one draw, it is still a coin flip, even with small tau.
counts = [1., 0]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(1., alpha)
pmf = dist.prob(counts)
self.assertAllClose(0.5, pmf.eval())
@@ -405,7 +405,7 @@ class DirichletMultinomialTest(test.TestCase):
# If there are two draws, it is much more likely that they are the same.
counts_same = [2., 0]
counts_different = [1, 1.]
- with self.test_session():
+ with self.cached_session():
dist = ds.DirichletMultinomial(2., alpha)
pmf_same = dist.prob(counts_same)
pmf_different = dist.prob(counts_different)
@@ -414,7 +414,7 @@ class DirichletMultinomialTest(test.TestCase):
def testNonStrictTurnsOffAllChecks(self):
# Make totally invalid input.
- with self.test_session():
+ with self.cached_session():
alpha = [[-1., 2]] # alpha should be positive.
counts = [[1., 0], [0., -1]] # counts should be non-negative.
n = [-5.3] # n should be a non negative integer equal to counts.sum.
@@ -422,7 +422,7 @@ class DirichletMultinomialTest(test.TestCase):
dist.prob(counts).eval() # Should not raise.
def testSampleUnbiasedNonScalarBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dist = ds.DirichletMultinomial(
total_count=5.,
concentration=1. + 2. * self._rng.rand(4, 3, 2).astype(np.float32))
@@ -451,7 +451,7 @@ class DirichletMultinomialTest(test.TestCase):
actual_covariance_, sample_covariance_, atol=0., rtol=0.20)
def testSampleUnbiasedScalarBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dist = ds.DirichletMultinomial(
total_count=5.,
concentration=1. + 2. * self._rng.rand(4).astype(np.float32))
diff --git a/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py b/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py
index b347c20db2..e35a8e1cdd 100644
--- a/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py
+++ b/tensorflow/python/kernel_tests/distributions/identity_bijector_test.py
@@ -42,7 +42,7 @@ class IdentityBijectorTest(test.TestCase):
bijector.forward_log_det_jacobian(x, event_ndims=3)))
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = identity_bijector.Identity()
bijector_test_util.assert_scalar_congruency(
bijector, lower_x=-2., upper_x=2.)
diff --git a/tensorflow/python/kernel_tests/distributions/kullback_leibler_test.py b/tensorflow/python/kernel_tests/distributions/kullback_leibler_test.py
index d0fa1fe989..e77e1117d4 100644
--- a/tensorflow/python/kernel_tests/distributions/kullback_leibler_test.py
+++ b/tensorflow/python/kernel_tests/distributions/kullback_leibler_test.py
@@ -58,7 +58,7 @@ class KLTest(test.TestCase):
# pylint: disable=unused-argument,unused-variable
- with self.test_session():
+ with self.cached_session():
a = MyDistException(loc=0.0, scale=1.0, allow_nan_stats=False)
kl = kullback_leibler.kl_divergence(a, a, allow_nan_stats=False)
with self.assertRaisesOpError(
diff --git a/tensorflow/python/kernel_tests/distributions/multinomial_test.py b/tensorflow/python/kernel_tests/distributions/multinomial_test.py
index bfd40ba2b7..3840d7331c 100644
--- a/tensorflow/python/kernel_tests/distributions/multinomial_test.py
+++ b/tensorflow/python/kernel_tests/distributions/multinomial_test.py
@@ -34,7 +34,7 @@ class MultinomialTest(test.TestCase):
self._rng = np.random.RandomState(42)
def testSimpleShapes(self):
- with self.test_session():
+ with self.cached_session():
p = [.1, .3, .6]
dist = multinomial.Multinomial(total_count=1., probs=p)
self.assertEqual(3, dist.event_shape_tensor().eval())
@@ -43,7 +43,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual(tensor_shape.TensorShape([]), dist.batch_shape)
def testComplexShapes(self):
- with self.test_session():
+ with self.cached_session():
p = 0.5 * np.ones([3, 2, 2], dtype=np.float32)
n = [[3., 2], [4, 5], [6, 7]]
dist = multinomial.Multinomial(total_count=n, probs=p)
@@ -55,14 +55,14 @@ class MultinomialTest(test.TestCase):
def testN(self):
p = [[0.1, 0.2, 0.7], [0.2, 0.3, 0.5]]
n = [[3.], [4]]
- with self.test_session():
+ with self.cached_session():
dist = multinomial.Multinomial(total_count=n, probs=p)
self.assertEqual((2, 1), dist.total_count.get_shape())
self.assertAllClose(n, dist.total_count.eval())
def testP(self):
p = [[0.1, 0.2, 0.7]]
- with self.test_session():
+ with self.cached_session():
dist = multinomial.Multinomial(total_count=3., probs=p)
self.assertEqual((1, 3), dist.probs.get_shape())
self.assertEqual((1, 3), dist.logits.get_shape())
@@ -71,7 +71,7 @@ class MultinomialTest(test.TestCase):
def testLogits(self):
p = np.array([[0.1, 0.2, 0.7]], dtype=np.float32)
logits = np.log(p) - 50.
- with self.test_session():
+ with self.cached_session():
multinom = multinomial.Multinomial(total_count=3., logits=logits)
self.assertEqual((1, 3), multinom.probs.get_shape())
self.assertEqual((1, 3), multinom.logits.get_shape())
@@ -80,7 +80,7 @@ class MultinomialTest(test.TestCase):
def testPmfUnderflow(self):
logits = np.array([[-200, 0]], dtype=np.float32)
- with self.test_session():
+ with self.cached_session():
dist = multinomial.Multinomial(total_count=1., logits=logits)
lp = dist.log_prob([1., 0.]).eval()[0]
self.assertAllClose(-200, lp, atol=0, rtol=1e-6)
@@ -88,7 +88,7 @@ class MultinomialTest(test.TestCase):
def testPmfandCountsAgree(self):
p = [[0.1, 0.2, 0.7]]
n = [[5.]]
- with self.test_session():
+ with self.cached_session():
dist = multinomial.Multinomial(total_count=n, probs=p, validate_args=True)
dist.prob([2., 3, 0]).eval()
dist.prob([3., 0, 2]).eval()
@@ -100,7 +100,7 @@ class MultinomialTest(test.TestCase):
def testPmfNonIntegerCounts(self):
p = [[0.1, 0.2, 0.7]]
n = [[5.]]
- with self.test_session():
+ with self.cached_session():
# No errors with integer n.
multinom = multinomial.Multinomial(
total_count=n, probs=p, validate_args=True)
@@ -122,7 +122,7 @@ class MultinomialTest(test.TestCase):
multinom.prob([1.0, 2.5, 1.5]).eval()
def testPmfBothZeroBatches(self):
- with self.test_session():
+ with self.cached_session():
# Both zero-batches. No broadcast
p = [0.5, 0.5]
counts = [1., 0]
@@ -131,7 +131,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual((), pmf.get_shape())
def testPmfBothZeroBatchesNontrivialN(self):
- with self.test_session():
+ with self.cached_session():
# Both zero-batches. No broadcast
p = [0.1, 0.9]
counts = [3., 2]
@@ -142,7 +142,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual((), pmf.get_shape())
def testPmfPStretchedInBroadcastWhenSameRank(self):
- with self.test_session():
+ with self.cached_session():
p = [[0.1, 0.9]]
counts = [[1., 0], [0, 1]]
pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
@@ -150,7 +150,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual((2), pmf.get_shape())
def testPmfPStretchedInBroadcastWhenLowerRank(self):
- with self.test_session():
+ with self.cached_session():
p = [0.1, 0.9]
counts = [[1., 0], [0, 1]]
pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
@@ -158,7 +158,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual((2), pmf.get_shape())
def testPmfCountsStretchedInBroadcastWhenSameRank(self):
- with self.test_session():
+ with self.cached_session():
p = [[0.1, 0.9], [0.7, 0.3]]
counts = [[1., 0]]
pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
@@ -166,7 +166,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual((2), pmf.get_shape())
def testPmfCountsStretchedInBroadcastWhenLowerRank(self):
- with self.test_session():
+ with self.cached_session():
p = [[0.1, 0.9], [0.7, 0.3]]
counts = [1., 0]
pmf = multinomial.Multinomial(total_count=1., probs=p).prob(counts)
@@ -174,7 +174,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual(pmf.get_shape(), (2))
def testPmfShapeCountsStretchedN(self):
- with self.test_session():
+ with self.cached_session():
# [2, 2, 2]
p = [[[0.1, 0.9], [0.1, 0.9]], [[0.7, 0.3], [0.7, 0.3]]]
# [2, 2]
@@ -186,7 +186,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual(pmf.get_shape(), (2, 2))
def testPmfShapeCountsPStretchedN(self):
- with self.test_session():
+ with self.cached_session():
p = [0.1, 0.9]
counts = [3., 2]
n = np.full([4, 3], 5., dtype=np.float32)
@@ -195,7 +195,7 @@ class MultinomialTest(test.TestCase):
self.assertEqual((4, 3), pmf.get_shape())
def testMultinomialMean(self):
- with self.test_session():
+ with self.cached_session():
n = 5.
p = [0.1, 0.2, 0.7]
dist = multinomial.Multinomial(total_count=n, probs=p)
@@ -204,7 +204,7 @@ class MultinomialTest(test.TestCase):
self.assertAllClose(expected_means, dist.mean().eval())
def testMultinomialCovariance(self):
- with self.test_session():
+ with self.cached_session():
n = 5.
p = [0.1, 0.2, 0.7]
dist = multinomial.Multinomial(total_count=n, probs=p)
@@ -215,7 +215,7 @@ class MultinomialTest(test.TestCase):
self.assertAllClose(expected_covariances, dist.covariance().eval())
def testMultinomialCovarianceBatch(self):
- with self.test_session():
+ with self.cached_session():
# Shape [2]
n = [5.] * 2
# Shape [4, 1, 2]
@@ -237,7 +237,7 @@ class MultinomialTest(test.TestCase):
ns = np.random.randint(low=1, high=11, size=[3, 5]).astype(np.float32)
ns2 = np.random.randint(low=1, high=11, size=[6, 1]).astype(np.float32)
- with self.test_session():
+ with self.cached_session():
dist = multinomial.Multinomial(ns, p)
dist2 = multinomial.Multinomial(ns2, p2)
@@ -253,7 +253,7 @@ class MultinomialTest(test.TestCase):
[2.5, 4, 0.01]], dtype=np.float32)
theta /= np.sum(theta, 1)[..., array_ops.newaxis]
n = np.array([[10., 9.], [8., 7.], [6., 5.]], dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# batch_shape=[3, 2], event_shape=[3]
dist = multinomial.Multinomial(n, theta)
x = dist.sample(int(1000e3), seed=1)
@@ -289,7 +289,7 @@ class MultinomialTest(test.TestCase):
self.assertAllClose(sample_stddev_, analytic_stddev, atol=0.01, rtol=0.01)
def testSampleUnbiasedNonScalarBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dist = multinomial.Multinomial(
total_count=[7., 6., 5.],
logits=math_ops.log(2. * self._rng.rand(4, 3, 2).astype(np.float32)))
@@ -318,7 +318,7 @@ class MultinomialTest(test.TestCase):
actual_covariance_, sample_covariance_, atol=0., rtol=0.20)
def testSampleUnbiasedScalarBatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dist = multinomial.Multinomial(
total_count=5.,
logits=math_ops.log(2. * self._rng.rand(4).astype(np.float32)))
diff --git a/tensorflow/python/kernel_tests/embedding_ops_test.py b/tensorflow/python/kernel_tests/embedding_ops_test.py
index 55d75cb474..dcd435e1ff 100644
--- a/tensorflow/python/kernel_tests/embedding_ops_test.py
+++ b/tensorflow/python/kernel_tests/embedding_ops_test.py
@@ -480,7 +480,7 @@ class EmbeddingLookupTest(test.TestCase):
id_vals, shape=ids_shape, dtype=dtypes.int32)
x, params, _ = _EmbeddingParams(num_shards, vocab_size, shape=[2])
y = embedding_ops.embedding_lookup(x, ids)
- y_shape = [num_ids] + list(params[_PName(0) + ":0"].shape[1:])
+ y_shape = ids_shape + tuple(params[_PName(0) + ":0"].shape[1:])
x_name = [_PName(i) for i in range(num_shards)]
x_init_value = [params[x_n + ":0"] for x_n in x_name]
x_shape = [i.shape for i in x_init_value]
@@ -663,8 +663,9 @@ class EmbeddingLookupSparseTest(test.TestCase):
np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry)
for num_shards, combiner, dtype, ignore_weights in itertools.product(
- [1, 5], ["sum", "mean", "sqrtn"], [dtypes.float32,
- dtypes.float64], [True, False]):
+ [1, 5], ["sum", "mean", "sqrtn"],
+ [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64],
+ [True, False]):
with self.test_session():
p, params, feed_dict = _EmbeddingParams(
@@ -677,6 +678,10 @@ class EmbeddingLookupSparseTest(test.TestCase):
self.assertEqual(embedding_sum.get_shape().as_list(),
expected_lookup_result_shape)
+ if dtype in (dtypes.float16, dtypes.bfloat16):
+ self.assertEqual(embedding_sum.dtype, dtypes.float32)
+ else:
+ self.assertEqual(embedding_sum.dtype, dtype)
tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)
@@ -692,7 +697,14 @@ class EmbeddingLookupSparseTest(test.TestCase):
if combiner == "sqrtn":
np_embedding_sum /= np.reshape(
np.sqrt(np_weight_sq_sum), (batch_size, 1, 1))
- self.assertAllClose(np_embedding_sum, tf_embedding_sum)
+
+ rtol = 1e-6
+ if dtype == dtypes.bfloat16:
+ rtol = 1e-2
+ elif dtype == dtypes.float16:
+ rtol = 1e-3
+ atol = rtol
+ self.assertAllClose(np_embedding_sum, tf_embedding_sum, rtol, atol)
def testGradientsEmbeddingLookupSparse(self):
vocab_size = 12
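
The hunk above widens the assertAllClose tolerances for the newly added low-precision dtypes. A plain-NumPy sketch of that tolerance ladder; tolerances_for is a hypothetical helper name, not part of the test:

import numpy as np

def tolerances_for(dtype_name):
  # Hypothetical helper mirroring the hunk: lower-precision dtypes get
  # looser relative bounds, and atol is set equal to rtol.
  rtol = {"bfloat16": 1e-2, "float16": 1e-3}.get(dtype_name, 1e-6)
  return rtol, rtol

expected = np.array([1.0, 2.0, 3.0])
actual = expected * (1.0 + 5e-4)  # simulated half-precision rounding error
rtol, atol = tolerances_for("float16")
assert np.allclose(expected, actual, rtol=rtol, atol=atol)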
diff --git a/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py b/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
index 60090a1510..e1f5a6b620 100644
--- a/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
+++ b/tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
@@ -25,6 +25,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed as random_seed_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
@@ -100,6 +102,24 @@ class ExtractImagePatchesGradTest(test.TestCase):
print('extract_image_patches gradient err: %.4e' % err)
self.assertLess(err, 1e-4)
+ def testConstructGradientWithLargeImages(self):
+ batch_size = 4
+ height = 1024
+ width = 1024
+ ksize = 5
+ images = variable_scope.get_variable('inputs',
+ (batch_size, height, width, 1))
+ patches = array_ops.extract_image_patches(images,
+ ksizes=[1, ksize, ksize, 1],
+ strides=[1, 1, 1, 1],
+ rates=[1, 1, 1, 1],
+ padding='SAME')
+ # GitHub issue #20146: the tf.extract_image_patches() gradient was
+ # very slow at graph construction time.
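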
+ gradients = gradients_impl.gradients(patches, images)
+ # Won't time out.
+ self.assertIsNotNone(gradients)
+
if __name__ == '__main__':
test.main()
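
The new test above guards against the quadratic Python loop the old gradient ran at graph-construction time. The rewritten gradient (see the array_grad.py hunk later in this diff) instead runs the patch extractor once over a tensor of indices to learn which input position feeds which output slot, then scatter-adds the upstream gradient. A toy 1-D NumPy sketch of that index-mapping idea, with illustrative sizes only:

import numpy as np

# Index 0 is reserved for padding; real input positions are 1..rows_in.
rows_in = 4
ksize = 3
input_idx = np.arange(1, rows_in + 1)
padded = np.concatenate([[0], input_idx, [0]])  # SAME padding, width 1
# "Extract patches" over the index tensor: each row lists the input
# positions contributing to one output location.
patches = np.stack([padded[i:i + ksize] for i in range(rows_in)])

grad_out = np.ones(patches.shape)   # upstream gradient, all ones
grad_in = np.zeros(rows_in + 1)     # slot 0 absorbs the padding gradient
np.add.at(grad_in, patches, grad_out)  # scatter-add by learned index
print(grad_in[1:])                  # -> [2. 3. 3. 2.]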
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index 5db2e9821d..1e76ad7476 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -1075,30 +1075,13 @@ class PartitionedCallTest(test.TestCase):
with ops.device("/cpu:2"):
s3 = iterator_ops.Iterator.from_structure(
(dtypes.float32,)).string_handle()
- with ops.device(""):
- # TODO(akshayka): This is unfortunate and brittle. It prevents
- # `Iterator.from_structure` from assigning the iterator op to 'cpu:0'.
- # Remove this hack once we have a way of obtaining metadata about
- # function execution.
- s4 = iterator_ops.Iterator.from_structure(
- (dtypes.float32,)).string_handle()
- return s1, s2, s3, s4
+ return s1, s2, s3
with self.test_session(config=config, use_gpu=True) as sess:
- with ops.device("/cpu:3"):
- outputs = sess.run(functional_ops.partitioned_call(args=[], f=Body))
- self.assertIn(compat.as_bytes("CPU:0"), outputs[0])
- self.assertIn(compat.as_bytes("CPU:1"), outputs[1])
- self.assertIn(compat.as_bytes("CPU:2"), outputs[2])
- self.assertIn(compat.as_bytes("CPU:3"), outputs[3])
-
- with self.test_session(config=config, use_gpu=True):
- with ops.device("/cpu:0"):
- outputs = sess.run(functional_ops.partitioned_call(args=[], f=Body))
+ outputs = sess.run(functional_ops.partitioned_call(args=[], f=Body))
self.assertIn(compat.as_bytes("CPU:0"), outputs[0])
self.assertIn(compat.as_bytes("CPU:1"), outputs[1])
self.assertIn(compat.as_bytes("CPU:2"), outputs[2])
- self.assertIn(compat.as_bytes("CPU:0"), outputs[3])
def testAssignAddResourceVariable(self):
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py
index 612a50bcec..99497914f2 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py
@@ -191,7 +191,7 @@ class NonSquareLinearOperatorCompositionTest(
linalg.LinearOperatorFullMatrix(rng.rand(2, 4, 5))
]
operator = linalg.LinearOperatorComposition(operators)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual((2, 3, 5), operator.shape_tensor().eval())
def test_shape_tensors_when_only_dynamically_available(self):
@@ -206,7 +206,7 @@ class NonSquareLinearOperatorCompositionTest(
linalg.LinearOperatorFullMatrix(mat_ph_2)
]
operator = linalg.LinearOperatorComposition(operators)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(
(1, 2, 3, 5), operator.shape_tensor().eval(feed_dict=feed_dict))
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
index 83cc8c483f..52861ae84a 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
@@ -52,7 +52,7 @@ class LinearOperatorDiagTest(
def test_assert_positive_definite_raises_for_zero_eigenvalue(self):
# Matrix with one positive eigenvalue and one zero eigenvalue.
- with self.test_session():
+ with self.cached_session():
diag = [1.0, 0.0]
operator = linalg.LinearOperatorDiag(diag)
@@ -62,7 +62,7 @@ class LinearOperatorDiagTest(
operator.assert_positive_definite().run()
def test_assert_positive_definite_raises_for_negative_real_eigvalues(self):
- with self.test_session():
+ with self.cached_session():
diag_x = [1.0, -2.0]
diag_y = [0., 0.] # Imaginary eigenvalues should not matter.
diag = math_ops.complex(diag_x, diag_y)
@@ -74,7 +74,7 @@ class LinearOperatorDiagTest(
operator.assert_positive_definite().run()
def test_assert_positive_definite_does_not_raise_if_pd_and_complex(self):
- with self.test_session():
+ with self.cached_session():
x = [1., 2.]
y = [1., 0.]
diag = math_ops.complex(x, y) # Re[diag] > 0.
@@ -83,14 +83,14 @@ class LinearOperatorDiagTest(
def test_assert_non_singular_raises_if_zero_eigenvalue(self):
# Singular matrix with one positive eigenvalue and one zero eigenvalue.
- with self.test_session():
+ with self.cached_session():
diag = [1.0, 0.0]
operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True)
with self.assertRaisesOpError("Singular operator"):
operator.assert_non_singular().run()
def test_assert_non_singular_does_not_raise_for_complex_nonsingular(self):
- with self.test_session():
+ with self.cached_session():
x = [1., 0.]
y = [0., 1.]
diag = math_ops.complex(x, y)
@@ -98,7 +98,7 @@ class LinearOperatorDiagTest(
linalg.LinearOperatorDiag(diag).assert_non_singular().run()
def test_assert_self_adjoint_raises_if_diag_has_complex_part(self):
- with self.test_session():
+ with self.cached_session():
x = [1., 0.]
y = [0., 1.]
diag = math_ops.complex(x, y)
@@ -107,7 +107,7 @@ class LinearOperatorDiagTest(
operator.assert_self_adjoint().run()
def test_assert_self_adjoint_does_not_raise_for_diag_with_zero_imag(self):
- with self.test_session():
+ with self.cached_session():
x = [1., 0.]
y = [0., 0.]
diag = math_ops.complex(x, y)
@@ -123,7 +123,7 @@ class LinearOperatorDiagTest(
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.matmul cannot handle.
# In particular, tf.matmul does not broadcast.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = random_ops.random_normal(shape=(2, 2, 3, 4))
# This LinearOperatorDiag will be broadcast to (2, 2, 3, 3) during solve
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py
index 1a40a29ec6..8373b5263f 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py
@@ -65,7 +65,7 @@ class SquareLinearOperatorFullMatrixTest(
self.assertTrue(operator.is_square)
def test_assert_non_singular_raises_if_cond_too_big_but_finite(self):
- with self.test_session():
+ with self.cached_session():
tril = linear_operator_test_util.random_tril_matrix(
shape=(50, 50), dtype=np.float32)
diag = np.logspace(-2, 2, 50).astype(np.float32)
@@ -80,7 +80,7 @@ class SquareLinearOperatorFullMatrixTest(
operator.assert_non_singular().run()
def test_assert_non_singular_raises_if_cond_infinite(self):
- with self.test_session():
+ with self.cached_session():
matrix = [[1., 1.], [1., 1.]]
# We don't pass the is_self_adjoint hint here, which means we take the
# generic code path.
@@ -91,14 +91,14 @@ class SquareLinearOperatorFullMatrixTest(
def test_assert_self_adjoint(self):
matrix = [[0., 1.], [0., 1.]]
operator = linalg.LinearOperatorFullMatrix(matrix)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("not equal to its adjoint"):
operator.assert_self_adjoint().run()
def test_assert_positive_definite(self):
matrix = [[1., 1.], [1., 1.]]
operator = linalg.LinearOperatorFullMatrix(matrix, is_self_adjoint=True)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Cholesky decomposition was not success"):
operator.assert_positive_definite().run()
@@ -158,7 +158,7 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
matrix = [[1., 1.], [1., 1.]]
operator = linalg.LinearOperatorFullMatrix(
matrix, is_self_adjoint=True, is_positive_definite=True)
- with self.test_session():
+ with self.cached_session():
# Cholesky decomposition may fail, so the error is not specific to
# non-singularity.
with self.assertRaisesOpError(""):
@@ -168,7 +168,7 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
matrix = [[0., 1.], [0., 1.]]
operator = linalg.LinearOperatorFullMatrix(
matrix, is_self_adjoint=True, is_positive_definite=True)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("not equal to its adjoint"):
operator.assert_self_adjoint().run()
@@ -176,7 +176,7 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
matrix = [[1., 1.], [1., 1.]]
operator = linalg.LinearOperatorFullMatrix(
matrix, is_self_adjoint=True, is_positive_definite=True)
- with self.test_session():
+ with self.cached_session():
# Cholesky decomposition may fail, so the error is not specific to
# non-singularity.
with self.assertRaisesOpError(""):
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
index 35dcf4417c..0c3c6b390f 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
@@ -57,24 +57,24 @@ class LinearOperatorIdentityTest(
return operator, mat
def test_assert_positive_definite(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
operator.assert_positive_definite().run() # Should not fail
def test_assert_non_singular(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
operator.assert_non_singular().run() # Should not fail
def test_assert_self_adjoint(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
operator.assert_self_adjoint().run() # Should not fail
def test_float16_matmul(self):
# float16 cannot be tested by base test class because tf.matrix_solve does
# not work with float16.
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorIdentity(
num_rows=2, dtype=dtypes.float16)
x = rng.randn(2, 3).astype(np.float16)
@@ -106,7 +106,7 @@ class LinearOperatorIdentityTest(
linalg_lib.LinearOperatorIdentity(num_rows=2, batch_shape=[-2])
def test_non_scalar_num_rows_raises_dynamic(self):
- with self.test_session():
+ with self.cached_session():
num_rows = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorIdentity(
num_rows, assert_proper_shapes=True)
@@ -114,7 +114,7 @@ class LinearOperatorIdentityTest(
operator.to_dense().eval(feed_dict={num_rows: [2]})
def test_negative_num_rows_raises_dynamic(self):
- with self.test_session():
+ with self.cached_session():
num_rows = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorIdentity(
num_rows, assert_proper_shapes=True)
@@ -122,7 +122,7 @@ class LinearOperatorIdentityTest(
operator.to_dense().eval(feed_dict={num_rows: -2})
def test_non_1d_batch_shape_raises_dynamic(self):
- with self.test_session():
+ with self.cached_session():
batch_shape = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorIdentity(
num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
@@ -130,7 +130,7 @@ class LinearOperatorIdentityTest(
operator.to_dense().eval(feed_dict={batch_shape: 2})
def test_negative_batch_shape_raises_dynamic(self):
- with self.test_session():
+ with self.cached_session():
batch_shape = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorIdentity(
num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
@@ -147,7 +147,7 @@ class LinearOperatorIdentityTest(
num_rows = array_ops.placeholder(dtypes.int32)
x = array_ops.placeholder(dtypes.float32)
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorIdentity(
num_rows, assert_proper_shapes=True)
y = operator.matmul(x)
@@ -158,7 +158,7 @@ class LinearOperatorIdentityTest(
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.batch_matmul cannot handle.
# In particular, tf.batch_matmul does not broadcast.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = random_ops.random_normal(shape=(1, 2, 3, 4))
operator = linalg_lib.LinearOperatorIdentity(num_rows=3, dtype=x.dtype)
@@ -172,7 +172,7 @@ class LinearOperatorIdentityTest(
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.batch_matmul cannot handle.
# In particular, tf.batch_matmul does not broadcast.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32)
operator = linalg_lib.LinearOperatorIdentity(num_rows=3, dtype=x.dtype)
@@ -188,7 +188,7 @@ class LinearOperatorIdentityTest(
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.batch_matmul cannot handle.
# In particular, tf.batch_matmul does not broadcast.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Given this x and LinearOperatorIdentity shape of (2, 1, 3, 3), the
# broadcast shape of operator and 'x' is (2, 2, 3, 4)
x = random_ops.random_normal(shape=(1, 2, 3, 4))
@@ -209,7 +209,7 @@ class LinearOperatorIdentityTest(
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.batch_matmul cannot handle.
# In particular, tf.batch_matmul does not broadcast.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Given this x and LinearOperatorIdentity shape of (2, 1, 3, 3), the
# broadcast shape of operator and 'x' is (2, 2, 3, 4)
x = array_ops.placeholder(dtypes.float32)
@@ -287,39 +287,39 @@ class LinearOperatorScaledIdentityTest(
return operator, matrix
def test_assert_positive_definite_does_not_raise_when_positive(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=1.)
operator.assert_positive_definite().run() # Should not fail
def test_assert_positive_definite_raises_when_negative(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=-1.)
with self.assertRaisesOpError("not positive definite"):
operator.assert_positive_definite().run()
def test_assert_non_singular_does_not_raise_when_non_singular(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=[1., 2., 3.])
operator.assert_non_singular().run() # Should not fail
def test_assert_non_singular_raises_when_singular(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=[1., 2., 0.])
with self.assertRaisesOpError("was singular"):
operator.assert_non_singular().run()
def test_assert_self_adjoint_does_not_raise_when_self_adjoint(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=[1. + 0J])
operator.assert_self_adjoint().run() # Should not fail
def test_assert_self_adjoint_raises_when_not_self_adjoint(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=[1. + 1J])
with self.assertRaisesOpError("not self-adjoint"):
@@ -328,7 +328,7 @@ class LinearOperatorScaledIdentityTest(
def test_float16_matmul(self):
# float16 cannot be tested by base test class because tf.matrix_solve does
# not work with float16.
- with self.test_session():
+ with self.cached_session():
multiplier = rng.rand(3).astype(np.float16)
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=multiplier)
@@ -353,7 +353,7 @@ class LinearOperatorScaledIdentityTest(
num_rows = array_ops.placeholder(dtypes.int32)
x = array_ops.placeholder(dtypes.float32)
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows, multiplier=[1., 2], assert_proper_shapes=True)
y = operator.matmul(x)
@@ -364,7 +364,7 @@ class LinearOperatorScaledIdentityTest(
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.batch_matmul cannot handle.
# In particular, tf.batch_matmul does not broadcast.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Given this x and LinearOperatorScaledIdentity shape of (2, 1, 3, 3), the
# broadcast shape of operator and 'x' is (2, 2, 3, 4)
x = random_ops.random_normal(shape=(1, 2, 3, 4))
@@ -392,7 +392,7 @@ class LinearOperatorScaledIdentityTest(
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.batch_matmul cannot handle.
# In particular, tf.batch_matmul does not broadcast.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Given this x and LinearOperatorScaledIdentity shape of (3, 3), the
# broadcast shape of operator and 'x' is (1, 2, 3, 4), which is the same
# shape as x.
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
index e26b946151..7e81c9c6c4 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
@@ -70,7 +70,7 @@ class KroneckerDenseTest(test.TestCase):
[10., 15., -2., -3.],
[5., 10., -1., -2.]], dtype=dtypes.float32)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(_kronecker_dense([x, y]).eval(), z.eval())
self.assertAllClose(_kronecker_dense([y, x]).eval(), w.eval())
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
index 0e38dbd48d..61268607a4 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
@@ -256,7 +256,7 @@ class LinearOpearatorLowRankUpdateBroadcastsShape(test.TestCase):
# domain_dimension is 3
self.assertAllEqual([2, 3, 3], operator.shape)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual([2, 3, 3], operator.to_dense().eval().shape)
def test_dynamic_shape_broadcasts_up_from_operator_to_other_args(self):
@@ -274,7 +274,7 @@ class LinearOpearatorLowRankUpdateBroadcastsShape(test.TestCase):
u_shape_ph: [2, 3, 2], # batch_shape = [2]
}
- with self.test_session():
+ with self.cached_session():
shape_tensor = operator.shape_tensor().eval(feed_dict=feed_dict)
self.assertAllEqual([2, 3, 3], shape_tensor)
dense = operator.to_dense().eval(feed_dict=feed_dict)
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
index b389e0cbdf..eb4bff915b 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
@@ -51,7 +51,7 @@ class LinearOperatorLowerTriangularTest(
def test_assert_non_singular(self):
# Singular matrix with one positive eigenvalue and one zero eigenvalue.
- with self.test_session():
+ with self.cached_session():
tril = [[1., 0.], [1., 0.]]
operator = linalg.LinearOperatorLowerTriangular(tril)
with self.assertRaisesOpError("Singular operator"):
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_test.py
index 8e9f0150a2..819347343b 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_test.py
@@ -108,7 +108,7 @@ class LinearOperatorTest(test.TestCase):
self.assertAllEqual(3, operator.range_dimension)
def test_all_shape_methods_defined_by_the_one_method_shape(self):
- with self.test_session():
+ with self.cached_session():
shape = (1, 2, 3, 4)
operator = LinearOperatorShape(shape)
@@ -131,7 +131,7 @@ class LinearOperatorTest(test.TestCase):
def test_generic_to_dense_method_non_square_matrix_static(self):
matrix = rng.randn(2, 3, 4)
operator = LinearOperatorMatmulSolve(matrix)
- with self.test_session():
+ with self.cached_session():
operator_dense = operator.to_dense()
self.assertAllEqual((2, 3, 4), operator_dense.get_shape())
self.assertAllClose(matrix, operator_dense.eval())
@@ -140,7 +140,7 @@ class LinearOperatorTest(test.TestCase):
matrix = rng.randn(2, 3, 4)
matrix_ph = array_ops.placeholder(dtypes.float64)
operator = LinearOperatorMatmulSolve(matrix_ph)
- with self.test_session():
+ with self.cached_session():
operator_dense = operator.to_dense()
self.assertAllClose(
matrix, operator_dense.eval(feed_dict={matrix_ph: matrix}))
@@ -149,7 +149,7 @@ class LinearOperatorTest(test.TestCase):
matrix = [[1., 0], [0., 2.]]
operator = LinearOperatorMatmulSolve(matrix)
x = [1., 1.]
- with self.test_session():
+ with self.cached_session():
y = operator.matvec(x)
self.assertAllEqual((2,), y.get_shape())
self.assertAllClose([1., 2.], y.eval())
@@ -158,7 +158,7 @@ class LinearOperatorTest(test.TestCase):
matrix = [[1., 0], [0., 2.]]
operator = LinearOperatorMatmulSolve(matrix)
y = [1., 1.]
- with self.test_session():
+ with self.cached_session():
x = operator.solvevec(y)
self.assertAllEqual((2,), x.get_shape())
self.assertAllClose([1., 1 / 2.], x.eval())
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
index 7b291e29de..86847d38c2 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_util_test.py
@@ -36,7 +36,7 @@ class AssertZeroImagPartTest(test.TestCase):
def test_real_tensor_doesnt_raise(self):
x = ops.convert_to_tensor([0., 2, 3])
- with self.test_session():
+ with self.cached_session():
# Should not raise.
linear_operator_util.assert_zero_imag_part(x, message="ABC123").run()
@@ -44,7 +44,7 @@ class AssertZeroImagPartTest(test.TestCase):
x = ops.convert_to_tensor([1., 0, 3])
y = ops.convert_to_tensor([0., 0, 0])
z = math_ops.complex(x, y)
- with self.test_session():
+ with self.cached_session():
# Should not raise.
linear_operator_util.assert_zero_imag_part(z, message="ABC123").run()
@@ -52,7 +52,7 @@ class AssertZeroImagPartTest(test.TestCase):
x = ops.convert_to_tensor([1., 2, 0])
y = ops.convert_to_tensor([1., 2, 0])
z = math_ops.complex(x, y)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("ABC123"):
linear_operator_util.assert_zero_imag_part(z, message="ABC123").run()
@@ -61,7 +61,7 @@ class AssertNoEntriesWithModulusZeroTest(test.TestCase):
def test_nonzero_real_tensor_doesnt_raise(self):
x = ops.convert_to_tensor([1., 2, 3])
- with self.test_session():
+ with self.cached_session():
# Should not raise.
linear_operator_util.assert_no_entries_with_modulus_zero(
x, message="ABC123").run()
@@ -70,14 +70,14 @@ class AssertNoEntriesWithModulusZeroTest(test.TestCase):
x = ops.convert_to_tensor([1., 0, 3])
y = ops.convert_to_tensor([1., 2, 0])
z = math_ops.complex(x, y)
- with self.test_session():
+ with self.cached_session():
# Should not raise.
linear_operator_util.assert_no_entries_with_modulus_zero(
z, message="ABC123").run()
def test_zero_real_tensor_raises(self):
x = ops.convert_to_tensor([1., 0, 3])
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("ABC123"):
linear_operator_util.assert_no_entries_with_modulus_zero(
x, message="ABC123").run()
@@ -86,7 +86,7 @@ class AssertNoEntriesWithModulusZeroTest(test.TestCase):
x = ops.convert_to_tensor([1., 2, 0])
y = ops.convert_to_tensor([1., 2, 0])
z = math_ops.complex(x, y)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("ABC123"):
linear_operator_util.assert_no_entries_with_modulus_zero(
z, message="ABC123").run()
@@ -103,7 +103,7 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
tensor, = linear_operator_util.broadcast_matrix_batch_dims([arr])
self.assertTrue(isinstance(tensor, ops.Tensor))
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(arr, tensor.eval())
def test_static_dims_broadcast(self):
@@ -118,7 +118,7 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
x_bc_, y_bc_ = sess.run([x_bc, y_bc])
@@ -137,7 +137,7 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
x_bc_, y_bc_ = sess.run([x_bc, y_bc])
@@ -159,7 +159,7 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_bc_, y_bc_ = sess.run([x_bc, y_bc], feed_dict={x_ph: x, y_ph: y})
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
@@ -179,7 +179,7 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_bc_, y_bc_ = sess.run([x_bc, y_bc], feed_dict={x_ph: x, y_ph: y})
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
@@ -203,7 +203,7 @@ class CholeskySolveWithBroadcastTest(test.TestCase):
rhs = rng.rand(2, 3, 7)
chol_broadcast = chol + np.zeros((2, 1, 1))
- with self.test_session():
+ with self.cached_session():
result = linear_operator_util.cholesky_solve_with_broadcast(chol, rhs)
self.assertAllEqual((2, 3, 7), result.get_shape())
expected = linalg_ops.cholesky_solve(chol_broadcast, rhs)
@@ -219,7 +219,7 @@ class CholeskySolveWithBroadcastTest(test.TestCase):
chol_ph = array_ops.placeholder(dtypes.float64)
rhs_ph = array_ops.placeholder(dtypes.float64)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result, expected = sess.run(
[
linear_operator_util.cholesky_solve_with_broadcast(
@@ -242,7 +242,7 @@ class MatmulWithBroadcastTest(test.TestCase):
y = rng.rand(3, 7)
y_broadcast = y + np.zeros((2, 1, 1))
- with self.test_session():
+ with self.cached_session():
result = linear_operator_util.matmul_with_broadcast(x, y)
self.assertAllEqual((2, 1, 7), result.get_shape())
expected = math_ops.matmul(x, y_broadcast)
@@ -258,7 +258,7 @@ class MatmulWithBroadcastTest(test.TestCase):
x_ph = array_ops.placeholder(dtypes.float64)
y_ph = array_ops.placeholder(dtypes.float64)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result, expected = sess.run(
[
linear_operator_util.matmul_with_broadcast(x_ph, y_ph),
@@ -279,7 +279,7 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
rhs = rng.rand(2, 3, 7)
matrix_broadcast = matrix + np.zeros((2, 1, 1))
- with self.test_session():
+ with self.cached_session():
result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
self.assertAllEqual((2, 3, 7), result.get_shape())
expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
@@ -295,7 +295,7 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
matrix_ph = array_ops.placeholder(dtypes.float64)
rhs_ph = array_ops.placeholder(dtypes.float64)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result, expected = sess.run(
[
linear_operator_util.matrix_solve_with_broadcast(
@@ -317,7 +317,7 @@ class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
rhs = rng.rand(3, 7)
rhs_broadcast = rhs + np.zeros((2, 1, 1))
- with self.test_session():
+ with self.cached_session():
result = linear_operator_util.matrix_triangular_solve_with_broadcast(
matrix, rhs)
self.assertAllEqual((2, 3, 7), result.get_shape())
@@ -333,7 +333,7 @@ class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
matrix_ph = array_ops.placeholder(dtypes.float64)
rhs_ph = array_ops.placeholder(dtypes.float64)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result, expected = sess.run(
[
linear_operator_util.matrix_triangular_solve_with_broadcast(
@@ -359,7 +359,7 @@ class DomainDimensionStubOperator(object):
class AssertCompatibleMatrixDimensionsTest(test.TestCase):
def test_compatible_dimensions_do_not_raise(self):
- with self.test_session():
+ with self.cached_session():
x = ops.convert_to_tensor(rng.rand(2, 3, 4))
operator = DomainDimensionStubOperator(3)
# Should not raise
@@ -367,7 +367,7 @@ class AssertCompatibleMatrixDimensionsTest(test.TestCase):
operator, x).run() # pyformat: disable
def test_incompatible_dimensions_raise(self):
- with self.test_session():
+ with self.cached_session():
x = ops.convert_to_tensor(rng.rand(2, 4, 4))
operator = DomainDimensionStubOperator(3)
with self.assertRaisesOpError("Incompatible matrix dimensions"):
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py
index 8f60b55e0a..f0556304ad 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_zeros_test.py
@@ -73,7 +73,7 @@ class LinearOperatorZerosTest(
operator.assert_non_singular()
def test_assert_self_adjoint(self):
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorZeros(num_rows=2)
operator.assert_self_adjoint().run() # Should not fail
@@ -108,7 +108,7 @@ class LinearOperatorZerosTest(
linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=[-2])
def test_non_scalar_num_rows_raises_dynamic(self):
- with self.test_session():
+ with self.cached_session():
num_rows = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorZeros(
num_rows, assert_proper_shapes=True)
@@ -116,7 +116,7 @@ class LinearOperatorZerosTest(
operator.to_dense().eval(feed_dict={num_rows: [2]})
def test_negative_num_rows_raises_dynamic(self):
- with self.test_session():
+ with self.cached_session():
n = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorZeros(
num_rows=n, assert_proper_shapes=True)
@@ -129,7 +129,7 @@ class LinearOperatorZerosTest(
operator.to_dense().eval(feed_dict={n: -2})
def test_non_1d_batch_shape_raises_dynamic(self):
- with self.test_session():
+ with self.cached_session():
batch_shape = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorZeros(
num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
@@ -137,7 +137,7 @@ class LinearOperatorZerosTest(
operator.to_dense().eval(feed_dict={batch_shape: 2})
def test_negative_batch_shape_raises_dynamic(self):
- with self.test_session():
+ with self.cached_session():
batch_shape = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorZeros(
num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
@@ -154,7 +154,7 @@ class LinearOperatorZerosTest(
num_rows = array_ops.placeholder(dtypes.int32)
x = array_ops.placeholder(dtypes.float32)
- with self.test_session():
+ with self.cached_session():
operator = linalg_lib.LinearOperatorZeros(
num_rows, assert_proper_shapes=True)
y = operator.matmul(x)
diff --git a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
index 24edc4f59f..723a15fbd1 100644
--- a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
+from tensorflow.python.ops.linalg import linalg_impl
from tensorflow.python.platform import test
@@ -39,7 +40,7 @@ class LogarithmOpTest(test.TestCase):
inp = x.astype(np_type)
with self.test_session(use_gpu=True):
# Verify that expm(logm(A)) == A.
- tf_ans = gen_linalg_ops.matrix_exponential(
+ tf_ans = linalg_impl.matrix_exponential(
gen_linalg_ops.matrix_logarithm(inp))
out = tf_ans.eval()
self.assertAllClose(inp, out, rtol=1e-4, atol=1e-3)
@@ -98,16 +99,25 @@ class LogarithmOpTest(test.TestCase):
self._verifyLogarithmComplex(np.empty([0, 2, 2], dtype=np.complex64))
self._verifyLogarithmComplex(np.empty([2, 0, 0], dtype=np.complex64))
- def testRandomSmallAndLarge(self):
+ def testRandomSmallAndLargeComplex64(self):
np.random.seed(42)
- for dtype in np.complex64, np.complex128:
- for batch_dims in [(), (1,), (3,), (2, 2)]:
- for size in 8, 31, 32:
- shape = batch_dims + (size, size)
- matrix = np.random.uniform(
- low=-1.0, high=1.0,
- size=np.prod(shape)).reshape(shape).astype(dtype)
- self._verifyLogarithmComplex(matrix)
+ for batch_dims in [(), (1,), (3,), (2, 2)]:
+ for size in 8, 31, 32:
+ shape = batch_dims + (size, size)
+ matrix = np.random.uniform(
+ low=-1.0, high=1.0,
+ size=np.prod(shape)).reshape(shape).astype(np.complex64)
+ self._verifyLogarithmComplex(matrix)
+
+ def testRandomSmallAndLargeComplex128(self):
+ np.random.seed(42)
+ for batch_dims in [(), (1,), (3,), (2, 2)]:
+ for size in 8, 31, 32:
+ shape = batch_dims + (size, size)
+ matrix = np.random.uniform(
+ low=-1.0, high=1.0,
+ size=np.prod(shape)).reshape(shape).astype(np.complex128)
+ self._verifyLogarithmComplex(matrix)
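
The hunk above splits the random logm test per complex dtype and routes matrix_exponential through linalg_impl; the invariant both variants check is the expm/logm round trip. A sketch of that property using SciPy as a stand-in for the TensorFlow kernels (tolerances copied from the test):

import numpy as np
from scipy.linalg import expm, logm

np.random.seed(42)
a = np.random.uniform(-1.0, 1.0, size=(8, 8)).astype(np.complex128)
# expm(logm(A)) should recover A up to the test's tolerances.
np.testing.assert_allclose(expm(logm(a)), a, rtol=1e-4, atol=1e-3)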
def testConcurrentExecutesWithoutError(self):
with self.test_session(use_gpu=True) as sess:
diff --git a/tensorflow/python/kernel_tests/partitioned_variables_test.py b/tensorflow/python/kernel_tests/partitioned_variables_test.py
index 1d0c2dceba..15d5702252 100644
--- a/tensorflow/python/kernel_tests/partitioned_variables_test.py
+++ b/tensorflow/python/kernel_tests/partitioned_variables_test.py
@@ -27,15 +27,12 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-from tensorflow.python.training import gradient_descent
from tensorflow.python.training import saver as saver_lib
@@ -549,6 +546,32 @@ class PartitionedVariablesTestCase(test.TestCase):
partitioned_variables.create_partitioned_variables(
[10, 43], [1, 50], rnd.initialized_value())
+ def testControlDepsNone(self):
+ with self.test_session() as session:
+ c = constant_op.constant(1.0)
+ with ops.control_dependencies([c]):
+ # d gets the control dependency.
+ d = constant_op.constant(2.0)
+ # Partitioned variables do not.
+ var_x = variable_scope.get_variable(
+ "x",
+ shape=[2],
+ initializer=init_ops.ones_initializer(),
+ partitioner=partitioned_variables.variable_axis_size_partitioner(4))
+
+ ops_before_read = session.graph.get_operations()
+ var_x.as_tensor() # Caches the ops for subsequent reads.
+ reading_ops = [
+ op for op in session.graph.get_operations()
+ if op not in ops_before_read
+ ]
+
+ self.assertEqual([c.op], d.op.control_inputs)
+ # Tests that no control dependencies are added when reading a partitioned
+ # variable, matching the behavior of reading a regular variable.
+ for op in reading_ops:
+ self.assertEqual([], op.control_inputs)
+
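
The new test pins down that reading a partitioned variable picks up no control dependencies, while an ordinary op created under control_dependencies does record the pending control input. A minimal graph-mode sketch of that baseline behavior:

import tensorflow as tf

with tf.Graph().as_default():
  c = tf.constant(1.0)
  with tf.control_dependencies([c]):
    d = tf.constant(2.0)  # d is an op, so it records c as a control input
  assert d.op.control_inputs == [c.op]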
def testConcat(self):
with self.test_session() as session:
var_x = variable_scope.get_variable(
@@ -574,57 +597,6 @@ class PartitionedVariablesTestCase(test.TestCase):
variables.global_variables_initializer().run()
self.assertAllClose(value.eval(), var_x.as_tensor().eval())
- def testVariableCreationInALoop(self):
- """Tests the variable created inside a loop can be used outside the loop."""
- with self.test_session():
- with variable_scope.variable_scope("ascope") as scope:
- def Body(i, _):
- var_x = variable_scope.get_variable(
- "x",
- shape=[2],
- initializer=init_ops.ones_initializer(),
- partitioner=partitioned_variables.variable_axis_size_partitioner(
- 4))
- return (i + 1, var_x.as_tensor())
-
- cond = lambda i, _: i < 2
- _, x = control_flow_ops.while_loop(
- cond, Body, (0, constant_op.constant([7, 8], dtypes.float32)))
- variables.global_variables_initializer().run()
- self.assertAllClose([1.0, 1.0], x.eval())
-
- scope.reuse_variables()
- var_x = variable_scope.get_variable(
- "x",
- shape=[2],
- initializer=init_ops.ones_initializer(),
- partitioner=partitioned_variables.variable_axis_size_partitioner(4))
-
- self.assertAllClose([1.0, 1.0], var_x.as_tensor().eval())
-
- def testReadInWhileLoop(self):
- """Tests the value is current (not cached) when read within a loop."""
- with self.test_session():
- var_x = variable_scope.get_variable(
- "x",
- shape=[2],
- initializer=init_ops.ones_initializer(),
- partitioner=partitioned_variables.variable_axis_size_partitioner(4))
-
- def Body(i, _):
- # Use a SGD step to update the variable's value.
- loss = math_ops.reduce_sum(var_x)
- optimizer = gradient_descent.GradientDescentOptimizer(1.0)
- minimize = optimizer.minimize(loss * 0.7)
- with ops.control_dependencies([minimize]):
- return (i + 1, var_x.as_tensor())
-
- cond = lambda i, _: i < 2
- _, x = control_flow_ops.while_loop(
- cond, Body, (0, constant_op.constant([7, 8], dtypes.float32)))
- variables.global_variables_initializer().run()
- self.assertAllClose([-0.4, -0.4], x.eval())
-
def testMetaGraphSaveLoad(self):
save_prefix = os.path.join(self.get_temp_dir(), "ckpt")
save_graph = ops.Graph()
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index 25e947f09e..657d92fa23 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -23,6 +23,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
@@ -71,6 +72,35 @@ class ReluTest(test.TestCase):
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
use_gpu=True)
+ def _testReluInt8x4(self, np_inputs):
+ if not test.is_gpu_available(cuda_only=True):
+ return
+ np_relu = self._npRelu(np_inputs)
+ with self.test_session(use_gpu=True):
+ relu = nn_ops.relu(constant_op.constant(np_inputs, dtypes.qint8))
+ if np_inputs.size % 4 == 0:
+ tf_relu = relu.eval()
+ self.assertAllClose(np_relu, tf_relu)
+ self.assertShapeEqual(np_relu, relu)
+ else:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Tensor size must be a multiple of 4 for Relu<qint8>. Got %d" %
+ np_inputs.size):
+ tf_relu = relu.eval()
+
+ def testReluInt8x4GoodShape(self):
+ self._testReluInt8x4(np.array([[-50, 7, 23, 0], [-1, -5, 6, 11]]))
+
+ def testReluInt8x4BadShape(self):
+ np_inputs = np.array([[-50, 7, 23], [0, 1, -5], [6, -2, 11]])
+ self.assertEqual(np_inputs.size, 9)
+ self._testReluInt8x4(np_inputs)
+ np_inputs = np.array(
+ [1, -2, 3, -4, 5, -6, 7, -8, 9, -8, 7, -6, 5, -4, 3, -2, 1])
+ self.assertEqual(np_inputs.size, 17)
+ self._testReluInt8x4(np_inputs)
+
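
The tests above exercise the fused qint8 ReLU path, whose GPU kernel processes four int8 values at a time, so the flat tensor size must be a multiple of 4, as the asserted error message states. A tiny sketch of that precondition (relu_qint8_size_ok is an illustrative name):

import numpy as np

def relu_qint8_size_ok(x):
  # The Relu<qint8> kernel vectorizes over int8x4 groups, so it rejects
  # tensors whose flat size is not a multiple of 4.
  return x.size % 4 == 0

assert relu_qint8_size_ok(np.zeros((2, 4), dtype=np.int8))
assert not relu_qint8_size_ok(np.zeros((3, 3), dtype=np.int8))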
# The gradient test for ReLU is a bit tricky as the derivative is not well
# defined around zero, so we keep the input values away from zero.
def testGradientFloat32(self):
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index f815348b2a..d0ed08933d 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -17,7 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import copy
import gc
+import os
+import pickle
import numpy as np
@@ -106,6 +109,27 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
v = resource_variable_ops.ResourceVariable(False, name="bool_test")
self.assertAllEqual(bool(v), False)
+ def testEagerDeepCopy(self):
+ with context.eager_mode():
+ init_value = np.ones((4, 4, 4))
+ variable = resource_variable_ops.ResourceVariable(init_value,
+ name="init")
+
+ copied_variable = copy.deepcopy(variable)
+ copied_variable.assign(4 * np.ones((4, 4, 4)))
+
+ # Copying the variable should create a new underlying tensor with distinct
+ # values.
+ self.assertFalse(np.allclose(variable.numpy(), copied_variable.numpy()))
+
+ def testGraphDeepCopy(self):
+ with self.test_session():
+ init_value = np.ones((4, 4, 4))
+ variable = resource_variable_ops.ResourceVariable(init_value,
+ name="init")
+ with self.assertRaises(NotImplementedError):
+ copy.deepcopy(variable)
+
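
The pair of tests above establish that deep-copying a ResourceVariable works in eager mode (producing an independent buffer) and raises NotImplementedError in graph mode. A minimal eager sketch, assuming a TF 1.x build that exposes tf.enable_eager_execution:

import copy
import numpy as np
import tensorflow as tf

tf.enable_eager_execution()
v = tf.Variable(np.ones((2, 2)), name="init")
w = copy.deepcopy(v)            # independent handle and buffer
w.assign(4 * np.ones((2, 2)))   # mutating the copy leaves v untouched
assert not np.allclose(v.numpy(), w.numpy())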
@test_util.run_in_graph_and_eager_modes
def testStridedSliceAssign(self):
v = resource_variable_ops.ResourceVariable([1.0, 2.0])
@@ -240,6 +264,18 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32)
self.assertEqual(self.evaluate(read), [[5]])
+ def testEagerPickle(self):
+ with context.eager_mode():
+ tmp_dir = self.get_temp_dir()
+ fname = os.path.join(tmp_dir, "var.pickle")
+ with open(fname, "wb") as f:
+ v = resource_variable_ops.ResourceVariable(10.0)
+ pickle.dump(v, f)
+
+ with open(fname, "rb") as f:
+ v = pickle.load(f)
+ self.assertAllEqual(v.numpy(), 10.0)
+
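
Similarly, the pickle test above round-trips an eager variable's current value through the default pickle protocol. A sketch under the same eager-mode assumption:

import os
import pickle
import tempfile
import tensorflow as tf

tf.enable_eager_execution()
path = os.path.join(tempfile.mkdtemp(), "var.pickle")
with open(path, "wb") as f:
  pickle.dump(tf.Variable(10.0), f)  # serializes the current value
with open(path, "rb") as f:
  restored = pickle.load(f)
assert restored.numpy() == 10.0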
@test_util.run_in_graph_and_eager_modes
def testScatterDiv(self):
handle = resource_variable_ops.var_handle_op(
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index c4f200a22e..562d11f0b0 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -229,6 +229,13 @@ class RNNTest(test.TestCase):
self.assertAllEqual([[[1, 1], [2, 2], [3, 3], [4, 4]]], outputs[1])
self.assertAllEqual(4, state)
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testEagerMemory(self):
+ with context.eager_mode():
+ cell = TensorArrayStateRNNCell()
+ inputs = np.array([[[1], [2], [3], [4]]], dtype=np.float32)
+ rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32, sequence_length=[4])
+
@test_util.run_in_graph_and_eager_modes
def testTensorArrayStateIsAccepted(self):
cell = TensorArrayStateRNNCell()
@@ -441,11 +448,11 @@ class RNNTest(test.TestCase):
cell, inputs, dtype=dtypes.float32)
self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
self.assertEqual(len(state), 4)
- self.assertEqual(state[0].shape.as_list(), [None, output_shape])
- self.assertEqual(state[1].shape.as_list(), [None, output_shape])
- self.assertEqual(state[2].shape.as_list(), [None, 2 * output_shape])
- self.assertEqual(state[3].shape.as_list(), [None, 2 * output_shape])
- loss = losses.softmax_cross_entropy(predict, state[0])
+ self.assertEqual(state[0].shape.as_list(), [None, 2 * output_shape])
+ self.assertEqual(state[1].shape.as_list(), [None, 2 * output_shape])
+ self.assertEqual(state[2].shape.as_list(), [None, output_shape])
+ self.assertEqual(state[3].shape.as_list(), [None, output_shape])
+ loss = losses.softmax_cross_entropy(predict, state[2])
train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
sess.run([variables_lib.global_variables_initializer()])
diff --git a/tensorflow/python/kernel_tests/stack_op_test.py b/tensorflow/python/kernel_tests/stack_op_test.py
index 2f27d1839b..2a33c594a4 100644
--- a/tensorflow/python/kernel_tests/stack_op_test.py
+++ b/tensorflow/python/kernel_tests/stack_op_test.py
@@ -277,6 +277,18 @@ class AutomaticStackingTest(test.TestCase):
[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], dtype=dtypes.float64)
self.assertEqual(dtypes.float64, t_2.dtype)
+ t_3 = ops.convert_to_tensor(
+ [[0., 0., 0.],
+ constant_op.constant([0., 0., 0.], dtype=dtypes.float64), [0., 0., 0.]
+ ],
+ dtype=dtypes.float32)
+ self.assertEqual(dtypes.float32, t_3.dtype)
+
+ t_4 = ops.convert_to_tensor(
+ [constant_op.constant([0., 0., 0.], dtype=dtypes.float64)],
+ dtype=dtypes.float32)
+ self.assertEqual(dtypes.float32, t_4.dtype)
+
with self.assertRaises(TypeError):
ops.convert_to_tensor([
constant_op.constant(
@@ -284,17 +296,15 @@ class AutomaticStackingTest(test.TestCase):
[0., 0., 0.], dtype=dtypes.float64), [0., 0., 0.]
])
- with self.assertRaises(TypeError):
- ops.convert_to_tensor(
- [[0., 0., 0.], constant_op.constant(
- [0., 0., 0.], dtype=dtypes.float64), [0., 0., 0.]],
- dtype=dtypes.float32)
+ def testDtypeConversionWhenTensorDtypeMismatch(self):
+ t_0 = ops.convert_to_tensor([0., 0., 0.])
+ self.assertEqual(dtypes.float32, t_0.dtype)
- with self.assertRaises(TypeError):
- ops.convert_to_tensor(
- [constant_op.constant(
- [0., 0., 0.], dtype=dtypes.float64)],
- dtype=dtypes.float32)
+ t_1 = ops.convert_to_tensor([0, 0, 0])
+ self.assertEqual(dtypes.int32, t_1.dtype)
+
+ t_2 = ops.convert_to_tensor([t_0, t_0, t_1], dtype=dtypes.float64)
+ self.assertEqual(dtypes.float64, t_2.dtype)
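
The rewritten assertions above reflect a behavior change: when an explicit dtype is passed, convert_to_tensor now casts mismatched element tensors instead of raising TypeError. A short graph-mode sketch:

import tensorflow as tf

with tf.Graph().as_default():
  f64 = tf.constant([0.0, 0.0, 0.0], dtype=tf.float64)
  # A float64 tensor mixed into a float32 conversion is now cast,
  # where it previously raised TypeError.
  t = tf.convert_to_tensor([[0.0, 0.0, 0.0], f64], dtype=tf.float32)
  assert t.dtype == tf.float32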
def testPlaceholder(self):
with self.test_session(use_gpu=True):
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index ae2a0ab29a..b736b12416 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -335,7 +335,7 @@ class VariableScopeTest(test.TestCase):
# reuse=True is for now only supported when eager execution is disabled.
if not context.executing_eagerly():
v = variable_scope.get_variable("v",
- []) # "v" is alredy there, reused
+ []) # "v" is already there, reused
losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(3, len(losses)) # No new loss added.
@@ -389,6 +389,18 @@ class VariableScopeTest(test.TestCase):
sess.run(v0.initializer)
sess.run(add)
+ def testEnableResourceVariables(self):
+ old = variable_scope._DEFAULT_USE_RESOURCE
+ try:
+ variable_scope.enable_resource_variables()
+ self.assertTrue(isinstance(variables_lib.Variable(1.0),
+ resource_variable_ops.ResourceVariable))
+ variable_scope.disable_resource_variables()
+ self.assertFalse(isinstance(variables_lib.Variable(1.0),
+ resource_variable_ops.ResourceVariable))
+ finally:
+ variable_scope._DEFAULT_USE_RESOURCE = old
+
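
The new test flips the process-wide default and restores it in a finally block. A sketch of the toggle, assuming a build that exports the public tf.enable_resource_variables alias for variable_scope.enable_resource_variables:

import tensorflow as tf

tf.enable_resource_variables()  # new tf.Variable objects become resource vars
v = tf.Variable(1.0)
print(type(v).__name__)         # -> 'ResourceVariable'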
def testControlFlow(self):
with self.test_session() as sess:
v0 = variable_scope.get_variable(
diff --git a/tensorflow/python/lib/io/file_io.i b/tensorflow/python/lib/io/file_io.i
index 891a7b0fd0..0aa08ea3d1 100644
--- a/tensorflow/python/lib/io/file_io.i
+++ b/tensorflow/python/lib/io/file_io.i
@@ -42,7 +42,7 @@ inline void FileExists(const string& filename, TF_Status* out_status) {
inline void FileExists(const tensorflow::StringPiece& filename,
TF_Status* out_status) {
tensorflow::Status status =
- tensorflow::Env::Default()->FileExists(filename.ToString());
+ tensorflow::Env::Default()->FileExists(string(filename));
if (!status.ok()) {
Set_TF_Status_from_Status(out_status, status);
}
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index a2b5f77f91..6ae869b89e 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from math import ceil
-
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -734,7 +732,6 @@ def _QuantizeAndDequantizeV3Grad(_, grad):
@ops.RegisterGradient("ExtractImagePatches")
def _ExtractImagePatchesGrad(op, grad):
-
batch_size, rows_in, cols_in, channels = [
dim.value for dim in op.inputs[0].get_shape()
]
@@ -742,28 +739,45 @@ def _ExtractImagePatchesGrad(op, grad):
batch_size = input_bhwc[0]
channels = input_bhwc[3]
+ # Create an indices matrix for the input tensor.
+ # Note that index 0 is reserved for padding locations,
+ # so input indices run from 1 to rows_in * cols_in (inclusive).
+ input_indices_num = 1 + rows_in * cols_in
+ input_idx = array_ops.reshape(math_ops.range(1, input_indices_num,
+ dtype=ops.dtypes.int64),
+ (1, rows_in, cols_in, 1))
+ input_idx_patched = gen_array_ops.extract_image_patches(
+ input_idx,
+ op.get_attr("ksizes"),
+ op.get_attr("strides"),
+ op.get_attr("rates"),
+ op.get_attr("padding"))
+
+ # Create indices matrix for output tensor.
_, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].get_shape()]
_, ksize_r, ksize_c, _ = op.get_attr("ksizes")
- _, stride_r, stride_h, _ = op.get_attr("strides")
- _, rate_r, rate_c, _ = op.get_attr("rates")
- padding = op.get_attr("padding")
-
- ksize_r_eff = ksize_r + (ksize_r - 1) * (rate_r - 1)
- ksize_c_eff = ksize_c + (ksize_c - 1) * (rate_c - 1)
-
- if padding == b"SAME":
- rows_out = int(ceil(rows_in / stride_r))
- cols_out = int(ceil(cols_in / stride_h))
- pad_rows = ((rows_out - 1) * stride_r + ksize_r_eff - rows_in) // 2
- pad_cols = ((cols_out - 1) * stride_h + ksize_c_eff - cols_in) // 2
-
- elif padding == b"VALID":
- rows_out = int(ceil((rows_in - ksize_r_eff + 1) / stride_r))
- cols_out = int(ceil((cols_in - ksize_c_eff + 1) / stride_h))
- pad_rows = (rows_out - 1) * stride_r + ksize_r_eff - rows_in
- pad_cols = (cols_out - 1) * stride_h + ksize_c_eff - cols_in
-
- pad_rows, pad_cols = max(0, pad_rows), max(0, pad_cols)
+ # Indices for output start from 0.
+ output_indices_num = rows_out * cols_out * ksize_r * ksize_c
+ output_idx = array_ops.reshape(math_ops.range(output_indices_num,
+ dtype=ops.dtypes.int64),
+ (1, rows_out, cols_out, ksize_r * ksize_c))
+
+ # Construct mapping table for indices: (input -> output).
+ idx_matrix = array_ops.concat(
+ [array_ops.expand_dims(input_idx_patched, axis=-1),
+ array_ops.expand_dims(output_idx, axis=-1)],
+ axis=-1)
+ idx_map = array_ops.reshape(idx_matrix, (-1, 2))
+
+ sp_shape = (input_indices_num, output_indices_num)
+ sp_mat_full = sparse_tensor.SparseTensor(
+ idx_map,
+ array_ops.ones([output_indices_num], dtype=grad.dtype),
+ sp_shape)
+ # Remove all padding locations [0, :].
+ sp_mat = sparse_ops.sparse_slice(sp_mat_full,
+ (1, 0),
+ (input_indices_num - 1, output_indices_num))
grad_expanded = array_ops.transpose(
array_ops.reshape(
@@ -771,27 +785,6 @@ def _ExtractImagePatchesGrad(op, grad):
(1, 2, 3, 4, 0, 5))
grad_flat = array_ops.reshape(grad_expanded, (-1, batch_size * channels))
- row_steps = range(0, rows_out * stride_r, stride_r)
- col_steps = range(0, cols_out * stride_h, stride_h)
-
- idx = []
- for i in range(rows_out):
- for j in range(cols_out):
- r_low, c_low = row_steps[i] - pad_rows, col_steps[j] - pad_cols
- r_high, c_high = r_low + ksize_r_eff, c_low + ksize_c_eff
-
- idx.extend([(r * (cols_in) + c, i * (cols_out * ksize_r * ksize_c) + j *
- (ksize_r * ksize_c) + ri * (ksize_c) + ci)
- for (ri, r) in enumerate(range(r_low, r_high, rate_r))
- for (ci, c) in enumerate(range(c_low, c_high, rate_c))
- if 0 <= r and r < rows_in and 0 <= c and c < cols_in])
-
- sp_shape = (rows_in * cols_in, rows_out * cols_out * ksize_r * ksize_c)
-
- sp_mat = sparse_tensor.SparseTensor(
- array_ops.constant(idx, dtype=ops.dtypes.int64),
- array_ops.ones((len(idx),), dtype=grad.dtype), sp_shape)
-
jac = sparse_ops.sparse_tensor_dense_matmul(sp_mat, grad_flat)
grad_out = array_ops.reshape(jac, (rows_in, cols_in, batch_size, channels))
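Editor's note: the rewritten gradient above replaces the Python-loop index construction with a graph-native trick: running `extract_image_patches` over a tensor of *input indices* directly yields, for every output patch slot, the input cell that feeds it, with index 0 reserved for padding. A minimal NumPy sketch of that trick (illustrative only; the stride-1, rate-1, VALID-padding setting and all names below are assumptions, not part of the patch):

```python
import numpy as np

rows_in = cols_in = 4
ksize = 2  # patch size; stride 1, rate 1, VALID padding assumed

# Index 0 is reserved for padding, so input cells are numbered
# 1 .. rows_in * cols_in.
input_idx = np.arange(1, 1 + rows_in * cols_in).reshape(rows_in, cols_in)

# Extracting patches of the *index* tensor tells us, for every output slot,
# which input cell feeds it -- no padding arithmetic needed.
patches = np.array([
    input_idx[r:r + ksize, c:c + ksize].ravel()
    for r in range(rows_in - ksize + 1)
    for c in range(cols_in - ksize + 1)
])  # shape: (rows_out * cols_out, ksize * ksize)

# Each (input_cell, output_slot) pair becomes one entry of the sparse
# Jacobian that sparse_tensor_dense_matmul multiplies against grad_flat.
output_idx = np.arange(patches.size).reshape(patches.shape)
pairs = np.stack([patches.ravel(), output_idx.ravel()], axis=1)
print(pairs[:5])  # [[1 0] [2 1] [5 2] [6 3] [2 4]]
```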
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 66bc4df18c..7bf3869ddf 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -43,6 +43,7 @@ from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops.gen_array_ops import *
from tensorflow.python.ops.gen_array_ops import reverse_v2 as reverse # pylint: disable=unused-import
from tensorflow.python.util import deprecation
+from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
@@ -691,28 +692,32 @@ def strided_slice(input_,
parent_name = name
- def assign(val, name=None):
- """Closure that holds all the arguments to create an assignment."""
-
- if var is None:
- raise ValueError("Sliced assignment is only supported for variables")
-
- if name is None:
- name = parent_name + "_assign"
-
- return var._strided_slice_assign(
- begin=begin,
- end=end,
- strides=strides,
- value=val,
- name=name,
- begin_mask=begin_mask,
- end_mask=end_mask,
- ellipsis_mask=ellipsis_mask,
- new_axis_mask=new_axis_mask,
- shrink_axis_mask=shrink_axis_mask)
-
- op.assign = assign
+ if not (var is None and isinstance(op, ops.EagerTensor)):
+ # TODO(b/113297051): Assigning a function to an EagerTensor seems to leak
+ # memory. Slicing variables still leaks; to mitigate the issue, ".assign" is
+ # removed for EagerTensors that are not variable slices.
+ def assign(val, name=None):
+ """Closure that holds all the arguments to create an assignment."""
+
+ if var is None:
+ raise ValueError("Sliced assignment is only supported for variables")
+
+ if name is None:
+ name = parent_name + "_assign"
+
+ return var._strided_slice_assign(
+ begin=begin,
+ end=end,
+ strides=strides,
+ value=val,
+ name=name,
+ begin_mask=begin_mask,
+ end_mask=end_mask,
+ ellipsis_mask=ellipsis_mask,
+ new_axis_mask=new_axis_mask,
+ shrink_axis_mask=shrink_axis_mask)
+
+ op.assign = assign
return op
@@ -944,6 +949,15 @@ def _get_dtype_from_nested_lists(list_or_tuple):
return None
+def _cast_nested_seqs_to_dtype(dtype):
+ def _maybe_cast(elem):
+ if ops.is_dense_tensor_like(elem):
+ if dtype != elem.dtype.base_dtype:
+ elem = gen_math_ops.cast(elem, dtype)
+ return elem
+ return _maybe_cast
+
+
def _autopacking_conversion_function(v, dtype=None, name=None, as_ref=False):
"""Tensor conversion function that automatically packs arguments."""
if as_ref:
@@ -953,9 +967,11 @@ def _autopacking_conversion_function(v, dtype=None, name=None, as_ref=False):
# We did not find any tensor-like objects in the nested lists, so defer to
# other conversion functions.
return NotImplemented
- if dtype is not None and dtype != inferred_dtype:
- return NotImplemented
- return _autopacking_helper(v, inferred_dtype, name or "packed")
+ if dtype is None:
+ dtype = inferred_dtype
+ elif dtype != inferred_dtype:
+ v = nest.map_structure(_cast_nested_seqs_to_dtype(dtype), v)
+ return _autopacking_helper(v, dtype, name or "packed")
# pylint: enable=invalid-name
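Editor's note: with this change `_autopacking_conversion_function` no longer bails out with `NotImplemented` when an explicit `dtype` disagrees with the inferred one; it casts any dense tensor-like elements inside the nested structure and then packs. A hedged sketch of the observable behavior, mirroring the tests added to `ops_test.py` above:

```python
import tensorflow as tf

t64 = tf.constant([0., 0., 0.], dtype=tf.float64)
# Previously this raised TypeError; with the cast-on-mismatch path above it
# packs to float32, casting the float64 element along the way.
packed = tf.convert_to_tensor([[0., 0., 0.], t64], dtype=tf.float32)
print(packed.dtype)  # <dtype: 'float32'>
```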
@@ -1711,7 +1727,7 @@ def placeholder(dtype, shape=None, name=None):
@compatibility(eager)
Placeholders are not compatible with eager execution.
@end_compatibility
-
+
Args:
dtype: The type of elements in the tensor to be fed.
shape: The shape of the tensor to be fed (optional). If the shape is not
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index c5a0f2949e..6528062f3c 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -1243,3 +1244,51 @@ def assert_scalar(tensor, name=None):
raise ValueError('Expected scalar shape for %s, saw shape: %s.'
% (tensor.name, shape))
return tensor
+
+
+@tf_export('ensure_shape')
+def ensure_shape(x, shape, name=None):
+ """Updates the shape of a tensor and checks at runtime that the shape holds.
+
+ For example:
+ ```python
+ x = tf.placeholder(tf.int32)
+ print(x.shape)
+ ==> TensorShape(None)
+ y = x * 2
+ print(y.shape)
+ ==> TensorShape(None)
+
+ y = tf.ensure_shape(y, (None, 3, 3))
+ print(y.shape)
+ ==> TensorShape([Dimension(None), Dimension(3), Dimension(3)])
+
+ with tf.Session() as sess:
+ # Raises tf.errors.InvalidArgumentError, because the shape (3,) is not
+ # compatible with the shape (None, 3, 3)
+ sess.run(y, feed_dict={x: [1, 2, 3]})
+
+ ```
+
+ NOTE: This differs from `Tensor.set_shape` in that it sets the static shape
+ of the resulting tensor and enforces it at runtime, raising an error if the
+ tensor's runtime shape is incompatible with the specified shape.
+ `Tensor.set_shape` sets the static shape of the tensor without enforcing it
+ at runtime, which may result in inconsistencies between the statically-known
+ shape of tensors and the runtime value of tensors.
+
+ Args:
+ x: A `Tensor`.
+ shape: A `TensorShape` representing the shape of this tensor, a
+ `TensorShapeProto`, a list, a tuple, or None.
+ name: A name for this operation (optional). Defaults to "EnsureShape".
+
+ Returns:
+ A `Tensor`. Has the same type and contents as `x`. At runtime, raises a
+ `tf.errors.InvalidArgumentError` if `shape` is incompatible with the shape
+ of `x`.
+ """
+ if not isinstance(shape, tensor_shape.TensorShape):
+ shape = tensor_shape.TensorShape(shape)
+
+ return array_ops.ensure_shape(x, shape, name=name)
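Editor's note: a short graph-mode usage sketch for the new `tf.ensure_shape`, complementing the docstring example above (TF 1.x session API assumed):

```python
import tensorflow as tf

x = tf.placeholder(tf.int32)
y = tf.ensure_shape(x * 2, (None, 3))  # static shape becomes (?, 3)

with tf.Session() as sess:
    print(sess.run(y, feed_dict={x: [[1, 2, 3]]}))  # [[2 4 6]]
    # Feeding [1, 2, 3] instead would raise InvalidArgumentError at run
    # time, whereas Tensor.set_shape would have gone unchecked.
```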
diff --git a/tensorflow/python/ops/collective_ops_test.py b/tensorflow/python/ops/collective_ops_test.py
index 9cc64ef9f6..6f3cd74406 100644
--- a/tensorflow/python/ops/collective_ops_test.py
+++ b/tensorflow/python/ops/collective_ops_test.py
@@ -53,6 +53,9 @@ class CollectiveOpTest(test.TestCase):
[0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3],
[0.2, 1.2, 2.2, 3.2, 4.2, 5.2, 6.2, 7.2])
+ def testCollectiveReduceScalar(self):
+ self._testCollectiveReduce(0.1, 0.3, 0.2)
+
def _testCollectiveBroadcast(self, t0):
group_key = 1
instance_key = 1
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index 76173e0f30..75a1a53eb7 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -24,7 +24,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
-from tensorflow.python.framework import function
+from tensorflow.python.eager import function
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.ops import gradients_impl
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index b3dacff6d6..c4e9c982b5 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -27,14 +27,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+
from tensorflow.core.framework import attr_value_pb2
-from tensorflow.python import pywrap_tensorflow as c_api
-from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_functional_ops
-from tensorflow.python.util import compat
# The following modules cannot be imported directly because they cause circular
@@ -57,46 +56,27 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
name = "cond"
with ops.name_scope(name) as scope:
- # Identify if there is a caller device, & get the innermost if possible.
- # pylint: disable=protected-access
- device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
- caller_device = device_funcs[-1] if device_funcs else None
-
- caller_colocation_stack = ops.get_default_graph()._colocation_stack
- caller_container = ops.get_default_graph()._container
- caller_collection_ref = ops.get_default_graph()._collections
-
with ops.name_scope(None):
# Find the outer most graph for uniquing function names.
# TODO(jpienaar): Make this work in eager mode.
graph = ops.get_default_graph()
- while isinstance(graph, _function._FuncGraph):
- graph = graph._outer_graph
+ while isinstance(graph, _function.FuncGraph):
+ graph = graph.outer_graph
true_name = graph.unique_name(("%strue" % scope).replace("/", "_"))
false_name = graph.unique_name(("%sfalse" % scope).replace("/", "_"))
- # pylint: enable=protected-access
+
true_graph = _function.func_graph_from_py_func(
- true_fn, [], [],
- name=true_name,
- device=caller_device,
- colocation_stack=caller_colocation_stack,
- collections_ref=caller_collection_ref,
- container=caller_container)
+ true_name, true_fn, [], {})
false_graph = _function.func_graph_from_py_func(
- false_fn, [], [],
- name=false_name,
- device=caller_device,
- colocation_stack=caller_colocation_stack,
- collections_ref=caller_collection_ref,
- container=caller_container)
+ false_name, false_fn, [], {})
_check_same_outputs(true_graph, false_graph)
# Add inputs to true_graph and false_graph to make them match. Note that
# this modifies true_graph and false_graph.
cond_inputs = _make_inputs_match(true_graph, false_graph,
- true_graph.extra_inputs,
- false_graph.extra_inputs)
+ true_graph.external_captures,
+ false_graph.external_captures)
# Add all intermediate tensors as function outputs so they're available for
# the gradient computation.
@@ -148,8 +128,8 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
true_graph, false_graph = _get_func_graphs(op)
# Note: op.graph != ops.get_default_graph() when we are computing the gradient
# of a nested cond.
- assert true_graph._outer_graph == op.graph
- assert false_graph._outer_graph == op.graph
+ assert true_graph.outer_graph == op.graph
+ assert false_graph.outer_graph == op.graph
# Create grad functions that compute the gradient of the true/false forward
# graphs. These functions will capture tensors from the forward pass
@@ -164,14 +144,13 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
# Resolve references to forward graph tensors in grad graphs and ensure
# they are in-scope, i.e., belong to one of outer graphs of the grad graph.
- true_grad_extra_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
- false_grad_extra_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
+ true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
+ false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)
# Make the inputs to true_grad_graph and false_grad_graph match. Note that
# this modifies true_grad_graph and false_grad_graph.
grad_inputs = _make_inputs_match(true_grad_graph, false_grad_graph,
- true_grad_extra_inputs,
- false_grad_extra_inputs)
+ true_grad_inputs, false_grad_inputs)
# Add all intermediate tensors as function outputs so they're available for
# higher-order gradient computations.
@@ -211,8 +190,8 @@ def _get_func_graphs(if_op):
"""
def _get_func_graph_for_branch(branch_name):
"""Generates and returns a _FuncGraph for the given branch."""
- extra_inputs = if_op.inputs[1:] # First input is pred.
- input_shapes = [t.shape for t in extra_inputs]
+ inputs = if_op.inputs[1:] # First input is pred.
+ input_shapes = [t.shape for t in inputs]
func_name = if_op.get_attr(branch_name).name
fdef = if_op.graph._get_function(func_name).definition
# `if_op.graph` may not be the same as `ops.get_default_graph()` e.g.
@@ -224,9 +203,8 @@ def _get_func_graphs(if_op):
with if_op.graph.as_default():
func_graph = _function_def_to_graph.function_def_to_graph(
fdef, input_shapes)
- func_graph.extra_inputs = extra_inputs
- func_graph.extra_args = func_graph.inputs
- func_graph._captured = dict(zip(extra_inputs, func_graph.inputs))
+ func_graph.captures = collections.OrderedDict(zip(inputs,
+ func_graph.inputs))
# Set the if op so that the gradient code can use it.
func_graph._if = if_op
return func_graph
@@ -282,12 +260,12 @@ def _grad_fn(func_graph, grads):
def _create_grad_func(func_graph, grads, name):
"""Returns the _FuncGraph representation of _grad_fn."""
- return _function.func_graph_from_py_func(lambda: _grad_fn(func_graph, grads),
- [], [], name)
+ return _function.func_graph_from_py_func(
+ name, lambda: _grad_fn(func_graph, grads), [], {})
def _resolve_grad_inputs(cond_graph, grad_graph):
- """Returns the tensors to pass as `extra_inputs` to `grad_graph`.
+ """Returns the tensors to pass as inputs to `grad_graph`.
The `grad_graph` may have external references to
1. Its outer graph containing the input gradients. These references are kept
@@ -305,10 +283,10 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
Returns:
A list of inputs tensors to be passed to grad_graph.
"""
- new_extra_inputs = []
+ new_inputs = []
- for t in grad_graph.extra_inputs:
- if t.graph != grad_graph._outer_graph:
+ for t in grad_graph.external_captures:
+ if t.graph != grad_graph.outer_graph:
# `t` is a tensor in `cond_graph` or one of its ancestors. We bubble this
# tensor to the least common ancestor of the `cond_graph` and
# `grad_graph` so that it is "in-scope" for `grad_graph`.
@@ -316,19 +294,19 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
# common ancestor once and re-use.
assert _is_ancestor(cond_graph, t.graph)
while not _is_ancestor(grad_graph, t.graph):
- assert isinstance(t.graph, _function._FuncGraph)
- if t in t.graph.extra_args:
- # TODO(srbs): Consider building a map of extra_args -> extra_inputs.
- # instead of searching for `t` twice.
- t = t.graph.extra_inputs[t.graph.extra_args.index(t)]
+ assert isinstance(t.graph, _function.FuncGraph)
+ if t in t.graph.internal_captures:
+ # TODO(srbs): Consider building a map of internal_captures ->
+ # external_captures instead of searching for `t` twice.
+ t = t.graph.external_captures[t.graph.internal_captures.index(t)]
else:
# Note: All intermediate tensors are output by the If op.
# TODO(srbs): .index() calls may be expensive. Optimize.
t = t.graph._if.outputs[t.graph.outputs.index(t)]
assert _is_ancestor(grad_graph, t.graph)
- new_extra_inputs.append(t)
+ new_inputs.append(t)
- return new_extra_inputs
+ return new_inputs
def _create_new_tf_function(func_graph):
@@ -340,26 +318,9 @@ def _create_new_tf_function(func_graph):
Returns:
The name of the new TF_Function.
"""
- c_func = c_api.TF_GraphToFunction_wrapper(
- func_graph._c_graph,
- compat.as_str(func_graph.name),
- False, # append_hash_to_fn_name
- None, # opers
- [t._as_tf_output() for t in func_graph.inputs],
- [t._as_tf_output() for t in func_graph.outputs],
- [],
- None, # opts
- None) # description
- _ = c_api_util.ScopedTFFunction(c_func)
-
- # TODO(b/109833212): this sucks, we're serializing the TF_Function*,
- # deserializing it into a Python FunctionDef, then reserializing it to create
- # a new TF_Function that we add to the graph.
- fdef = _function.function_def_from_tf_function(c_func)
- defined_func = _function._from_definition(fdef)
- defined_func._sub_functions = func_graph._functions
- defined_func.add_to_graph(func_graph._outer_graph)
-
+ func = _function._EagerDefinedFunction(
+ func_graph.name, func_graph, func_graph.inputs, func_graph.outputs, {})
+ func.add_to_graph(func_graph.outer_graph)
return func_graph.name
@@ -421,21 +382,20 @@ def _pad_params(true_graph, false_graph, true_params, false_params):
return new_true_params, new_false_inputs
-def _make_inputs_match(true_graph, false_graph, true_extra_inputs,
- false_extra_inputs):
+def _make_inputs_match(true_graph, false_graph, true_inputs, false_inputs):
"""Modifies true_graph and false_graph so they have the same input signature.
This method reorders and/or adds parameters to true_graph and false_graph so
- they have the same input signature, and updates the 'inputs', 'extra_inputs',
- and '_captured' fields of both graphs accordingly. It uses the input tensors
- from the outer graph to avoid duplicating shared arguments.
+ they have the same input signature, and updates the 'inputs' and 'captured'
+ fields of both graphs accordingly. It uses the input tensors from the outer
+ graph to avoid duplicating shared arguments.
Args:
true_graph: function._FuncGraph
false_graph: function._FuncGraph
- true_extra_inputs: a list of Tensors in the outer graph. The inputs for
+ true_inputs: a list of Tensors in the outer graph. The inputs for
true_graph.
- false_extra_inputs: a list of Tensors in the outer graph. The inputs for
+ false_inputs: a list of Tensors in the outer graph. The inputs for
false_graph.
Returns:
@@ -444,12 +404,12 @@ def _make_inputs_match(true_graph, false_graph, true_extra_inputs,
false_inputs.
"""
shared_inputs, true_only_inputs, false_only_inputs = _separate_unique_inputs(
- true_extra_inputs, false_extra_inputs)
+ true_inputs, false_inputs)
new_inputs = shared_inputs + true_only_inputs + false_only_inputs
- true_input_to_param = dict(zip(true_extra_inputs, true_graph.inputs))
- false_input_to_param = dict(zip(false_extra_inputs, false_graph.inputs))
+ true_input_to_param = dict(zip(true_inputs, true_graph.inputs))
+ false_input_to_param = dict(zip(false_inputs, false_graph.inputs))
true_graph.inputs = (
[true_input_to_param[t] for t in shared_inputs] +
@@ -462,14 +422,10 @@ def _make_inputs_match(true_graph, false_graph, true_extra_inputs,
[false_input_to_param[t] for t in false_only_inputs])
# Rewrite the _FuncGraphs' state to reflect the new inputs.
- true_graph.extra_inputs = new_inputs
- false_graph.extra_inputs = new_inputs
-
- true_graph.extra_args = true_graph.inputs
- false_graph.extra_args = false_graph.inputs
-
- true_graph._captured = dict(zip(new_inputs, true_graph.inputs))
- false_graph._captured = dict(zip(new_inputs, false_graph.inputs))
+ true_graph.captures = collections.OrderedDict(zip(new_inputs,
+ true_graph.inputs))
+ false_graph.captures = collections.OrderedDict(zip(new_inputs,
+ false_graph.inputs))
return new_inputs
@@ -506,10 +462,10 @@ def _get_grad_fn_name(func_graph):
counter = 1
has_conflict = True
while has_conflict:
- curr_graph = func_graph._outer_graph
+ curr_graph = func_graph.outer_graph
has_conflict = curr_graph._is_function(name)
- while not has_conflict and isinstance(curr_graph, _function._FuncGraph):
- curr_graph = curr_graph._outer_graph
+ while not has_conflict and isinstance(curr_graph, _function.FuncGraph):
+ curr_graph = curr_graph.outer_graph
has_conflict = curr_graph._is_function(name)
if has_conflict:
name = "%s_%s" % (base_name, counter)
@@ -534,6 +490,6 @@ def _check_same_outputs(true_graph, false_graph):
def _is_ancestor(graph, maybe_ancestor):
if maybe_ancestor == graph:
return True
- if isinstance(graph, _function._FuncGraph):
- return _is_ancestor(graph._outer_graph, maybe_ancestor)
+ if isinstance(graph, _function.FuncGraph):
+ return _is_ancestor(graph.outer_graph, maybe_ancestor)
return False
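Editor's note: the migration above replaces the `_FuncGraph.extra_inputs`/`extra_args`/`_captured` triple with `FuncGraph.captures`, an ordered mapping from external tensors to their internal placeholders. A pure-Python sketch of the bookkeeping that `_resolve_grad_inputs` relies on (the class and tensor names are illustrative stand-ins, not the real API):

```python
from collections import OrderedDict

class FakeFuncGraph:
    """Stand-in for function.FuncGraph, keeping only the capture mapping."""

    def __init__(self, outer_graph):
        self.outer_graph = outer_graph
        self.captures = OrderedDict()  # external tensor -> internal placeholder

    @property
    def external_captures(self):
        return list(self.captures.keys())

    @property
    def internal_captures(self):
        return list(self.captures.values())

g = FakeFuncGraph(outer_graph="outer")
g.captures["outer_t"] = "inner_placeholder"

# _resolve_grad_inputs walks an internal capture back to the external
# tensor it was captured from:
t = "inner_placeholder"
if t in g.internal_captures:
    t = g.external_captures[g.internal_captures.index(t)]
print(t)  # outer_t
```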
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index d1095c8954..e3c1aa3d5a 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1966,8 +1966,12 @@ def cond(pred,
`true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
`false_fn` must have the same non-zero number and type of outputs.
- Note that the conditional execution applies only to the operations defined in
- `true_fn` and `false_fn`. Consider the following simple program:
+ **WARNING**: Any Tensors or Operations created outside of `true_fn` and
+ `false_fn` will be executed regardless of which branch is selected at runtime.
+
+ Although this behavior is consistent with the dataflow model of TensorFlow,
+ it has frequently surprised users who expected a lazier semantics.
+ Consider the following simple program:
```python
z = tf.multiply(a, b)
@@ -1978,8 +1982,6 @@ def cond(pred,
operation will not be executed. Since `z` is needed for at least one
branch of the `cond`, the `tf.multiply` operation is always executed,
unconditionally.
- Although this behavior is consistent with the dataflow model of TensorFlow,
- it has occasionally surprised some users who expected a lazier semantics.
Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
call to `cond`, and not at all during `Session.run()`). `cond`
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 7b9e7de145..6263041b8d 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -134,7 +134,10 @@ def _embedding_lookup_and_transform(params,
ids, max_norm)
if transform_fn:
result = transform_fn(result)
- return result
+ # Make sure the final result does not have colocation constraints on the
+ # params. This is similar to the np > 1 case, where parallel_dynamic_stitch
+ # sits outside the scope of all the `with ops.colocate_with(params[p])` blocks.
+ return array_ops.identity(result)
else:
# Flatten the ids. There are two cases where we need to do this.
# - There is more than one params tensor.
@@ -427,6 +430,8 @@ def embedding_lookup_sparse(params,
embeddings = embedding_lookup(
params, ids, partition_strategy=partition_strategy, max_norm=max_norm)
+ if embeddings.dtype in (dtypes.float16, dtypes.bfloat16):
+ embeddings = math_ops.to_float(embeddings)
if not ignore_weights:
weights = sp_weights.values
if weights.dtype != embeddings.dtype:
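Editor's note: the second hunk up-casts `float16`/`bfloat16` embeddings to `float32` before the weighted combine, since the weights are multiplied in and segment-summed in the embeddings' dtype. A small NumPy illustration of the precision concern (values chosen arbitrarily):

```python
import numpy as np

embeddings = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float16)
weights = np.array([0.1, 0.9], dtype=np.float32)

emb32 = embeddings.astype(np.float32)  # mirrors math_ops.to_float(embeddings)
combined = (weights[:, None] * emb32).sum(axis=0)
print(combined)  # [2.8 3.8], accumulated in float32 rather than float16
```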
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index fb51fbc626..561a341cf3 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -22,6 +22,7 @@ import collections
import functools
import six
+from tensorflow.python.compat import compat as fwd_compat
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -299,6 +300,7 @@ class HashTable(InitializableLookupTableBase):
self._value_shape))
return exported_keys, exported_values
+
class TableInitializerBase(object):
"""Base class for lookup table initializers."""
@@ -370,8 +372,13 @@ class KeyValueTensorInitializer(TableInitializerBase):
# Ensure a unique name when eager execution is enabled to avoid spurious
# sharing issues.
scope += str(ops.uid())
- init_op = gen_lookup_ops.initialize_table_v2(
- table.table_ref, self._keys, self._values, name=scope)
+ if fwd_compat.forward_compatible(2018, 9, 19):
+ init_op = gen_lookup_ops.lookup_table_import_v2(
+ table.table_ref, self._keys, self._values, name=scope)
+ else:
+ # To maintain forward compatibility, use the old implementation.
+ init_op = gen_lookup_ops.initialize_table_v2(
+ table.table_ref, self._keys, self._values, name=scope)
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
return init_op
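Editor's note: the gate above is the standard forward-compatibility pattern: graphs built before the horizon date keep emitting the old op so they can still run on older binaries. A hedged sketch of the pattern in isolation (the `new_impl`/`old_impl` callables are placeholders):

```python
from tensorflow.python.compat import compat as fwd_compat

def make_init_op(new_impl, old_impl):
    # After 2018-09-19 (plus the compatibility window), every supported
    # runtime understands the new op, so it is safe to emit it.
    if fwd_compat.forward_compatible(2018, 9, 19):
        return new_impl()
    return old_impl()
```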
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 67ea534639..9b0ab00c7a 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -618,7 +618,7 @@ def cast(x, dtype, name=None):
"""Casts a tensor to a new type.
The operation casts `x` (in case of `Tensor`) or `x.values`
- (in case of `SparseTensor`) to `dtype`.
+ (in case of `SparseTensor` or `IndexedSlices`) to `dtype`.
For example:
@@ -637,15 +637,16 @@ def cast(x, dtype, name=None):
behavior of numpy.
Args:
- x: A `Tensor` or `SparseTensor` of numeric type. It could be
- `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`, `int64`,
- `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`.
- dtype: The destination type. The list of supported dtypes is the same
- as `x`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices` of numeric type. It could
+ be `uint8`, `uint16`, `uint32`, `uint64`, `int8`, `int16`, `int32`,
+ `int64`, `float16`, `float32`, `float64`, `complex64`, `complex128`,
+ `bfloat16`.
+ dtype: The destination type. The list of supported dtypes is the same as
+ `x`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` and
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` and
same type as `dtype`.
Raises:
@@ -659,6 +660,9 @@ def cast(x, dtype, name=None):
if isinstance(x, sparse_tensor.SparseTensor):
values_cast = cast(x.values, base_type, name=name)
x = sparse_tensor.SparseTensor(x.indices, values_cast, x.dense_shape)
+ elif isinstance(x, ops.IndexedSlices):
+ values_cast = cast(x.values, base_type, name=name)
+ x = ops.IndexedSlices(values_cast, x.indices, x.dense_shape)
else:
# TODO(josh11b): If x is not already a Tensor, we could return
# ops.convert_to_tensor(x, dtype=dtype, ...) here, but that
@@ -711,11 +715,12 @@ def to_float(x, name="ToFloat"):
"""Casts a tensor to type `float32`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `float32`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `float32`.
Raises:
TypeError: If `x` cannot be cast to the `float32`.
@@ -728,11 +733,12 @@ def to_double(x, name="ToDouble"):
"""Casts a tensor to type `float64`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `float64`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `float64`.
Raises:
TypeError: If `x` cannot be cast to the `float64`.
@@ -745,11 +751,12 @@ def to_int32(x, name="ToInt32"):
"""Casts a tensor to type `int32`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `int32`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `int32`.
Raises:
TypeError: If `x` cannot be cast to the `int32`.
@@ -762,11 +769,12 @@ def to_int64(x, name="ToInt64"):
"""Casts a tensor to type `int64`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `int64`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `int64`.
Raises:
TypeError: If `x` cannot be cast to the `int64`.
@@ -779,11 +787,12 @@ def to_bfloat16(x, name="ToBFloat16"):
"""Casts a tensor to type `bfloat16`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `bfloat16`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `bfloat16`.
Raises:
TypeError: If `x` cannot be cast to the `bfloat16`.
@@ -796,11 +805,12 @@ def to_complex64(x, name="ToComplex64"):
"""Casts a tensor to type `complex64`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `complex64`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `complex64`.
Raises:
TypeError: If `x` cannot be cast to the `complex64`.
@@ -813,11 +823,12 @@ def to_complex128(x, name="ToComplex128"):
"""Casts a tensor to type `complex128`.
Args:
- x: A `Tensor` or `SparseTensor`.
+ x: A `Tensor` or `SparseTensor` or `IndexedSlices`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x` with type `complex128`.
+ A `Tensor` or `SparseTensor` or `IndexedSlices` with same shape as `x` with
+ type `complex128`.
Raises:
TypeError: If `x` cannot be cast to the `complex128`.
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index a648653909..e1a01ab4c3 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -27,7 +27,6 @@ from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import sparse_ops
@ops.RegisterGradient("Conv2DBackpropInput")
@@ -977,25 +976,30 @@ def _TopKGrad(op, grad, _):
in_shape = array_ops.shape(op.inputs[0])
ind_shape = array_ops.shape(op.outputs[1])
- ind_lastdim = array_ops.gather(ind_shape, array_ops.size(ind_shape) - 1)
+ # int32 is not supported on GPU, hence the up-cast to int64.
+ ind_lastdim = array_ops.gather(math_ops.cast(
+ ind_shape, dtypes.int64), array_ops.size(ind_shape) - 1)
# Flatten indices to 2D.
ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack([-1, ind_lastdim]))
- in_lastdim = array_ops.gather(in_shape, array_ops.size(in_shape) - 1)
+ in_lastdim = array_ops.gather(math_ops.cast(
+ in_shape, dtypes.int64), array_ops.size(in_shape) - 1)
outerdim = array_ops.shape(ind_2d)[0]
# Compute linear indices (flattened to 1D).
- ind = array_ops.reshape(ind_2d + array_ops.expand_dims(
- math_ops.range(0, outerdim * in_lastdim, in_lastdim), -1), [-1])
+ ind = array_ops.reshape(ind_2d + math_ops.cast(array_ops.expand_dims(
+ math_ops.range(0, math_ops.cast(outerdim, dtypes.int64)
+ * in_lastdim, in_lastdim), -1), dtypes.int32), [-1])
# Substitute grad to appropriate locations and fill the rest with zeros,
# finally reshaping it to the original input shape.
return [
array_ops.reshape(
- sparse_ops.sparse_to_dense(
- ind,
- array_ops.reshape(math_ops.reduce_prod(in_shape), [1]),
+ array_ops.scatter_nd(
+ array_ops.expand_dims(ind, -1),
array_ops.reshape(grad, [-1]),
- validate_indices=False), in_shape),
+ [math_ops.reduce_prod(in_shape)]
+ ),
+ in_shape),
array_ops.zeros([], dtype=dtypes.int32)
]
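Editor's note: the TopK gradient now scatters via `scatter_nd` over flattened linear indices instead of `sparse_to_dense`, with the index math up-cast to `int64` for GPU support. A NumPy sketch of the scatter being performed (k, shapes, and gradient values are arbitrary examples):

```python
import numpy as np

values = np.array([[5., 1., 4.], [2., 9., 3.]])
k = 1
ind = values.argsort(axis=-1)[:, ::-1][:, :k]  # top-k indices per row
grad = np.array([[10.], [20.]])

in_lastdim = values.shape[-1]
outer = np.arange(values.shape[0])[:, None]
linear = (ind + outer * in_lastdim).ravel()    # flattened positions

out = np.zeros(values.size)
out[linear] = grad.ravel()                     # scatter_nd equivalent
print(out.reshape(values.shape))
# [[10.  0.  0.]
#  [ 0. 20.  0.]]
```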
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index edc6e04b48..474e0bb295 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1586,7 +1586,7 @@ def leaky_relu(features, alpha=0.2, name=None):
"Rectifier Nonlinearities Improve Neural Network Acoustic Models"
AL Maas, AY Hannun, AY Ng - Proc. ICML, 2013
- http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf
+ https://ai.stanford.edu/~amaas/papers/relu_hybrid_icml2013_final.pdf
Args:
features: A `Tensor` representing preactivation values. Must be one of
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index d8d9af545f..6041e2a0c5 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -629,76 +629,12 @@ def _parse_example_raw(serialized,
Returns:
A `dict` mapping keys to `Tensor`s and `SparseTensor`s.
- Raises:
- ValueError: If sparse and dense key sets intersect, or input lengths do not
- match up.
"""
with ops.name_scope(name, "ParseExample", [serialized, names]):
- names = [] if names is None else names
- dense_defaults = collections.OrderedDict(
- ) if dense_defaults is None else dense_defaults
- sparse_keys = [] if sparse_keys is None else sparse_keys
- sparse_types = [] if sparse_types is None else sparse_types
- dense_keys = [] if dense_keys is None else dense_keys
- dense_types = [] if dense_types is None else dense_types
- dense_shapes = (
- [[]] * len(dense_keys) if dense_shapes is None else dense_shapes)
-
- num_dense = len(dense_keys)
- num_sparse = len(sparse_keys)
-
- if len(dense_shapes) != num_dense:
- raise ValueError("len(dense_shapes) != len(dense_keys): %d vs. %d"
- % (len(dense_shapes), num_dense))
- if len(dense_types) != num_dense:
- raise ValueError("len(dense_types) != len(num_dense): %d vs. %d"
- % (len(dense_types), num_dense))
- if len(sparse_types) != num_sparse:
- raise ValueError("len(sparse_types) != len(sparse_keys): %d vs. %d"
- % (len(sparse_types), num_sparse))
- if num_dense + num_sparse == 0:
- raise ValueError("Must provide at least one sparse key or dense key")
- if not set(dense_keys).isdisjoint(set(sparse_keys)):
- raise ValueError(
- "Dense and sparse keys must not intersect; intersection: %s" %
- set(dense_keys).intersection(set(sparse_keys)))
-
- # Convert dense_shapes to TensorShape object.
- dense_shapes = [tensor_shape.as_shape(shape) for shape in dense_shapes]
-
- dense_defaults_vec = []
- for i, key in enumerate(dense_keys):
- default_value = dense_defaults.get(key)
- dense_shape = dense_shapes[i]
- if (dense_shape.ndims is not None and dense_shape.ndims > 0 and
- dense_shape[0].value is None):
- # Variable stride dense shape, the default value should be a
- # scalar padding value
- if default_value is None:
- default_value = ops.convert_to_tensor(
- "" if dense_types[i] == dtypes.string else 0,
- dtype=dense_types[i])
- else:
- # Reshape to a scalar to ensure user gets an error if they
- # provide a tensor that's not intended to be a padding value
- # (0 or 2+ elements).
- key_name = "padding_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
- default_value = ops.convert_to_tensor(
- default_value, dtype=dense_types[i], name=key_name)
- default_value = array_ops.reshape(default_value, [])
- else:
- if default_value is None:
- default_value = constant_op.constant([], dtype=dense_types[i])
- elif not isinstance(default_value, ops.Tensor):
- key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
- default_value = ops.convert_to_tensor(
- default_value, dtype=dense_types[i], name=key_name)
- default_value = array_ops.reshape(default_value, dense_shape)
-
- dense_defaults_vec.append(default_value)
-
- # Finally, convert dense_shapes to TensorShapeProto
- dense_shapes = [shape.as_proto() for shape in dense_shapes]
+ (names, dense_defaults_vec, sparse_keys, sparse_types,
+ dense_keys, dense_shapes, _) = _process_raw_parameters(
+ names, dense_defaults, sparse_keys, sparse_types, dense_keys,
+ dense_types, dense_shapes)
outputs = gen_parsing_ops.parse_example(
serialized=serialized,
@@ -719,6 +655,112 @@ def _parse_example_raw(serialized,
return dict(zip(sparse_keys + dense_keys, sparse_tensors + dense_values))
+def _process_raw_parameters(names, dense_defaults, sparse_keys, sparse_types,
+ dense_keys, dense_types, dense_shapes):
+ """Process raw parameters to params used by `gen_parsing_ops`.
+
+ Args:
+ names: A vector (1-D Tensor) of strings (optional), the names of
+ the serialized protos.
+ dense_defaults: A dict mapping string keys to `Tensor`s.
+ The keys of the dict must match the dense_keys of the feature.
+ sparse_keys: A list of string keys in the examples' features.
+ The results for these keys will be returned as `SparseTensor` objects.
+ sparse_types: A list of `DTypes` of the same length as `sparse_keys`.
+ Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+ and `tf.string` (`BytesList`) are supported.
+ dense_keys: A list of string keys in the examples' features.
+ The results for these keys will be returned as `Tensor`s.
+ dense_types: A list of DTypes of the same length as `dense_keys`.
+ Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+ and `tf.string` (`BytesList`) are supported.
+ dense_shapes: A list of tuples with the same length as `dense_keys`.
+ The shape of the data for each dense feature referenced by `dense_keys`.
+ Required for any input tensors identified by `dense_keys`. Must be
+ either fully defined, or may contain an unknown first dimension.
+ An unknown first dimension means the feature is treated as having
+ a variable number of blocks, and the output shape along this dimension
+ is considered unknown at graph build time. Padding is applied for
+ minibatch elements smaller than the maximum number of blocks for the
+ given feature along this dimension.
+
+ Returns:
+ Tuple of `names`, `dense_defaults_vec`, `sparse_keys`, `sparse_types`,
+ `dense_keys`, `dense_shapes`.
+
+ Raises:
+ ValueError: If sparse and dense key sets intersect, or input lengths do not
+ match up.
+ """
+ names = [] if names is None else names
+ dense_defaults = collections.OrderedDict(
+ ) if dense_defaults is None else dense_defaults
+ sparse_keys = [] if sparse_keys is None else sparse_keys
+ sparse_types = [] if sparse_types is None else sparse_types
+ dense_keys = [] if dense_keys is None else dense_keys
+ dense_types = [] if dense_types is None else dense_types
+ dense_shapes = ([[]] * len(dense_keys)
+ if dense_shapes is None else dense_shapes)
+
+ num_dense = len(dense_keys)
+ num_sparse = len(sparse_keys)
+
+ if len(dense_shapes) != num_dense:
+ raise ValueError("len(dense_shapes) != len(dense_keys): %d vs. %d" %
+ (len(dense_shapes), num_dense))
+ if len(dense_types) != num_dense:
+ raise ValueError("len(dense_types) != len(dense_keys): %d vs. %d" %
+ (len(dense_types), num_dense))
+ if len(sparse_types) != num_sparse:
+ raise ValueError("len(sparse_types) != len(sparse_keys): %d vs. %d" %
+ (len(sparse_types), num_sparse))
+ if num_dense + num_sparse == 0:
+ raise ValueError("Must provide at least one sparse key or dense key")
+ if not set(dense_keys).isdisjoint(set(sparse_keys)):
+ raise ValueError(
+ "Dense and sparse keys must not intersect; intersection: %s" %
+ set(dense_keys).intersection(set(sparse_keys)))
+
+ # Convert dense_shapes to TensorShape object.
+ dense_shapes = [tensor_shape.as_shape(shape) for shape in dense_shapes]
+
+ dense_defaults_vec = []
+ for i, key in enumerate(dense_keys):
+ default_value = dense_defaults.get(key)
+ dense_shape = dense_shapes[i]
+ if (dense_shape.ndims is not None and dense_shape.ndims > 0 and
+ dense_shape[0].value is None):
+ # Variable stride dense shape, the default value should be a
+ # scalar padding value
+ if default_value is None:
+ default_value = ops.convert_to_tensor(
+ "" if dense_types[i] == dtypes.string else 0, dtype=dense_types[i])
+ else:
+ # Reshape to a scalar to ensure user gets an error if they
+ # provide a tensor that's not intended to be a padding value
+ # (0 or 2+ elements).
+ key_name = "padding_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
+ default_value = ops.convert_to_tensor(
+ default_value, dtype=dense_types[i], name=key_name)
+ default_value = array_ops.reshape(default_value, [])
+ else:
+ if default_value is None:
+ default_value = constant_op.constant([], dtype=dense_types[i])
+ elif not isinstance(default_value, ops.Tensor):
+ key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
+ default_value = ops.convert_to_tensor(
+ default_value, dtype=dense_types[i], name=key_name)
+ default_value = array_ops.reshape(default_value, dense_shape)
+
+ dense_defaults_vec.append(default_value)
+
+ # Finally, convert dense_shapes to TensorShapeProto
+ dense_shapes_as_proto = [shape.as_proto() for shape in dense_shapes]
+
+ return (names, dense_defaults_vec, sparse_keys, sparse_types, dense_keys,
+ dense_shapes_as_proto, dense_shapes)
+
+
@tf_export("parse_single_example")
def parse_single_example(serialized, features, name=None, example_names=None):
"""Parses a single `Example` proto.
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index c5bc74132e..4800352ac2 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -355,6 +355,15 @@ class ResourceVariable(variables.RefVariable):
raise ValueError("initial_value must be specified.")
init_from_fn = callable(initial_value)
+ if isinstance(initial_value, ops.Tensor) and hasattr(
+ initial_value, "graph") and initial_value.graph.building_function:
+ raise ValueError("Tensor-typed variable initializers must either be "
+ "wrapped in an init_scope or callable "
+ "(e.g., `tf.Variable(lambda : "
+ "tf.truncated_normal([10, 40]))`) when building "
+ "functions. Please file a feature request if this "
+ "restriction inconveniences you.")
+
if collections is None:
collections = [ops.GraphKeys.GLOBAL_VARIABLES]
if not isinstance(collections, (list, tuple, set)):
@@ -586,6 +595,22 @@ class ResourceVariable(variables.RefVariable):
def __bool__(self):
return bool(self.read_value())
+ def __copy__(self):
+ return self
+
+ def __deepcopy__(self, memo):
+ if not context.executing_eagerly():
+ raise NotImplementedError(
+ "__deepcopy__() is only available when eager execution is enabled.")
+ copied_variable = ResourceVariable(
+ initial_value=self.read_value(),
+ trainable=self._trainable,
+ constraint=self._constraint,
+ dtype=self._dtype,
+ name=self._shared_name + "_copy")
+ memo[self._unique_id] = copied_variable
+ return copied_variable
+
@property
def dtype(self):
"""The dtype of this variable."""
@@ -958,6 +983,9 @@ class ResourceVariable(variables.RefVariable):
return self._lazy_read(assign_op)
return assign_op
+ def __reduce__(self):
+ return (ResourceVariable, (self.numpy(),))
+
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
"""Subtracts `IndexedSlices` from this variable.
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index d990386b9a..d1b8be4df7 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -96,6 +96,60 @@ def _make_int64_tensor(value, name):
return math_ops.cast(value, dtypes.int64)
+@tf_export("sparse.expand_dims")
+def sparse_expand_dims(sp_input, axis=None, name=None):
+ """Inserts a dimension of 1 into a tensor's shape.
+
+ Given a tensor `sp_input`, this operation inserts a dimension of 1 at the
+ dimension index `axis` of `sp_input`'s shape. The dimension index `axis`
+ starts at zero; if you specify a negative number for `axis` it is counted
+ backwards from the end.
+
+ Args:
+ sp_input: A `SparseTensor`.
+ axis: 0-D (scalar). Specifies the dimension index at which to expand the
+ shape of `sp_input`. Must be in the range `[-rank(sp_input) - 1,
+ rank(sp_input)]`.
+ name: The name of the output `SparseTensor`.
+
+ Returns:
+ A `SparseTensor` with the same data as `sp_input`, but its shape has an
+ additional dimension of size 1 added.
+ """
+ rank = sp_input.dense_shape.get_shape()[0]
+ axis = -1 if axis is None else axis
+
+ with ops.name_scope(name, default_name="expand_dims", values=[sp_input]):
+ if isinstance(axis, compat.integral_types):
+ axis = ops.convert_to_tensor(axis, name="axis", dtype=dtypes.int32)
+ elif not isinstance(axis, ops.Tensor):
+ raise TypeError("axis must be an integer value in range [-rank(sp_input)"
+ " - 1, rank(sp_input)]")
+
+ # Convert axis to a positive value if it is negative.
+ axis = array_ops.where(axis >= 0, axis, axis + rank + 1)
+
+ # Create the new column of indices for the sparse tensor by slicing
+ # the indices and inserting a new column of indices for the new dimension.
+ column_size = array_ops.shape(sp_input.indices)[0]
+ new_index = array_ops.zeros([column_size, 1], dtype=dtypes.int64)
+ indices_before = array_ops.slice(sp_input.indices, [0, 0], [-1, axis])
+ indices_after = array_ops.slice(sp_input.indices, [0, axis], [-1, -1])
+ indices = array_ops.concat(
+ [indices_before, new_index, indices_after], axis=1)
+
+ # Create the new dense shape by splicing the tensor [1] in the correct
+ # dimension of the existing shape.
+ shape_before = array_ops.slice(sp_input.dense_shape, [0], [axis])
+ shape_after = array_ops.slice(sp_input.dense_shape, [axis], [-1])
+ new_shape = ops.convert_to_tensor([1], name="new_shape", dtype=dtypes.int64)
+ shape = array_ops.concat([shape_before, new_shape, shape_after], axis=0)
+
+ # Create the output sparse tensor.
+ return sparse_tensor.SparseTensor(
+ indices=indices, values=sp_input.values, dense_shape=shape)
+
+
@tf_export("sparse.eye")
def sparse_eye(num_rows,
num_columns=None,
@@ -835,6 +889,9 @@ def sparse_reduce_max(sp_input, axis=None, keepdims=None,
`tf.reduce_max()`. In particular, this Op also returns a dense `Tensor`
instead of a sparse one.
+ Note: A gradient is not defined for this function, so it can't be used
+ to train models that rely on gradient descent.
+
Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless
`keepdims` is true, the rank of the tensor is reduced by 1 for each entry in
`reduction_axes`. If `keepdims` is true, the reduced dimensions are retained
@@ -902,6 +959,9 @@ def sparse_reduce_max_sparse(sp_input,
`tf.reduce_max()`. In contrast to SparseReduceSum, this Op returns a
SparseTensor.
+ Note: A gradient is not defined for this function, so it can't be used
+ to train models that rely on gradient descent.
+
Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless
`keepdims` is true, the rank of the tensor is reduced by 1 for each entry in
`reduction_axes`. If `keepdims` is true, the reduced dimensions are retained
@@ -1003,6 +1063,9 @@ def sparse_reduce_sum_sparse(sp_input,
`tf.reduce_sum()`. In contrast to SparseReduceSum, this Op returns a
SparseTensor.
+ Note: A gradient is not defined for this function, so it can't be used
+ to train models that rely on gradient descent.
+
Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless
`keepdims` is true, the rank of the tensor is reduced by 1 for each entry in
`reduction_axes`. If `keepdims` is true, the reduced dimensions are retained
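Editor's note: a short usage sketch for the new `tf.sparse.expand_dims` added above (TF 1.x session API assumed):

```python
import tensorflow as tf

sp = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2],
                     dense_shape=[2, 3])      # shape (2, 3)
expanded = tf.sparse.expand_dims(sp, axis=0)  # shape (1, 2, 3)

with tf.Session() as sess:
    print(sess.run(expanded.dense_shape))     # [1 2 3]
```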
diff --git a/tensorflow/python/ops/sparse_ops_test.py b/tensorflow/python/ops/sparse_ops_test.py
index b10c3c2187..4ee1569249 100644
--- a/tensorflow/python/ops/sparse_ops_test.py
+++ b/tensorflow/python/ops/sparse_ops_test.py
@@ -21,6 +21,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import googletest
@@ -45,5 +47,35 @@ class SparseOpsTest(test_util.TensorFlowTestCase):
test_one(n, m, True)
test_one(n, m, False)
+ def testSparseExpandDims(self):
+ for rank in range(1, 4):
+ # Create a dummy input. When rank=3, shape=[2, 4, 6].
+ shape = np.arange(1, rank + 1) * 2
+ before = np.arange(np.prod(shape)).reshape(shape)
+
+ # Make entries sparse.
+ before *= np.random.binomial(1, .2, before.shape)
+ dense_shape = before.shape
+ indices = np.array(np.where(before)).T
+ values = before[before != 0]
+
+ # Try every possible valid value of axis.
+ for axis in range(-rank - 1, rank):
+ expected_after = np.expand_dims(before, axis)
+
+ for axis_as_tensor in [False, True]:
+ dense_shape_t = constant_op.constant(dense_shape, dtype=dtypes.int64)
+ indices_t = constant_op.constant(indices)
+ values_t = constant_op.constant(values)
+ before_t = sparse_tensor.SparseTensor(
+ indices=indices_t, values=values_t, dense_shape=dense_shape_t)
+
+ if axis_as_tensor:
+ axis = constant_op.constant(axis)
+
+ s = sparse_ops.sparse_expand_dims(before_t, axis)
+ d = sparse_ops.sparse_to_dense(s.indices, s.dense_shape, s.values)
+ self.assertAllEqual(self.evaluate(d), expected_after)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 46bcd68f1a..f53e06fdf9 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -40,6 +40,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect
@@ -205,6 +206,42 @@ it does exist, simply return it.
"""
+_DEFAULT_USE_RESOURCE = False
+
+
+@tf_export(v1=["enable_resource_variables"])
+def enable_resource_variables():
+ """Creates resource variables by default.
+
+ Resource variables are improved versions of TensorFlow variables with a
+ well-defined memory model. Accessing a resource variable reads its value, and
+ all ops which access a specific read value of the variable are guaranteed to
+ see the same value for that tensor. Writes which happen after a read (by
+ having a control or data dependency on the read) are guaranteed not to affect
+ the value of the read tensor, and similarly writes which happen before a read
+ are guaranteed to affect the value. No guarantees are made about unordered
+ read/write pairs.
+
+ Calling tf.enable_resource_variables() lets you opt-in to this TensorFlow 2.0
+ feature.
+ """
+ global _DEFAULT_USE_RESOURCE
+ _DEFAULT_USE_RESOURCE = True
+
+
+@deprecation.deprecated(
+ None, "non-resource variables are not supported in the long term")
+@tf_export(v1=["disable_resource_variables"])
+def disable_resource_variables():
+ """Opts out of resource variables.
+
+ If your code needs tf.disable_resource_variables() to be called to work
+ properly, please file a bug.
+ """
+ global _DEFAULT_USE_RESOURCE
+ _DEFAULT_USE_RESOURCE = False
+
+
class _VariableStore(object):
"""Variable store that carries a number of named Variables.
@@ -868,7 +905,7 @@ class _VariableStore(object):
# Create the variable.
if use_resource is None:
# Set the default value if unspecified.
- use_resource = False
+ use_resource = _DEFAULT_USE_RESOURCE
v = variable(
initial_value=init_val,
name=name,
@@ -2369,6 +2406,8 @@ def default_variable_creator(next_creator=None, **kwargs):
if use_resource is None:
use_resource = get_variable_scope().use_resource
+ if use_resource is None:
+ use_resource = _DEFAULT_USE_RESOURCE
use_resource = use_resource or context.executing_eagerly()
if use_resource:
return resource_variable_ops.ResourceVariable(
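Editor's note: a minimal sketch mirroring `testEnableResourceVariables` above; flipping the global default changes what `tf.Variable` constructs (TF 1.x API assumed):

```python
import tensorflow as tf
from tensorflow.python.ops import resource_variable_ops

tf.enable_resource_variables()
v = tf.Variable(1.0)
print(isinstance(v, resource_variable_ops.ResourceVariable))  # True

tf.disable_resource_variables()  # deprecated escape hatch
w = tf.Variable(1.0)
print(isinstance(w, resource_variable_ops.ResourceVariable))  # False
```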
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index c4eff6c57b..f7da3f7d64 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -459,7 +459,7 @@ class Variable(six.with_metaclass(VariableMetaclass,
"""
raise NotImplementedError
- def assign(self, value, use_locking=False):
+ def assign(self, value, use_locking=False, name=None, read_value=True):
"""Assigns a new value to the variable.
This is essentially a shortcut for `assign(self, value)`.
@@ -467,6 +467,9 @@ class Variable(six.with_metaclass(VariableMetaclass,
Args:
value: A `Tensor`. The new value for this variable.
use_locking: If `True`, use locking during the assignment.
+ name: The name of the operation to be created.
+ read_value: If True, returns something which evaluates to the
+ new value of the variable; if False, returns the assign op.
Returns:
A `Tensor` that will hold the new value of this variable after
@@ -474,7 +477,7 @@ class Variable(six.with_metaclass(VariableMetaclass,
"""
raise NotImplementedError
- def assign_add(self, delta, use_locking=False):
+ def assign_add(self, delta, use_locking=False, name=None, read_value=True):
"""Adds a value to this variable.
This is essentially a shortcut for `assign_add(self, delta)`.
@@ -482,6 +485,9 @@ class Variable(six.with_metaclass(VariableMetaclass,
Args:
delta: A `Tensor`. The value to add to this variable.
use_locking: If `True`, use locking during the operation.
+ name: The name of the operation to be created.
+ read_value: If True, returns something which evaluates to the
+ new value of the variable; if False, returns the assign op.
Returns:
A `Tensor` that will hold the new value of this variable after
@@ -489,7 +495,7 @@ class Variable(six.with_metaclass(VariableMetaclass,
"""
raise NotImplementedError
- def assign_sub(self, delta, use_locking=False):
+ def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
"""Subtracts a value from this variable.
This is essentially a shortcut for `assign_sub(self, delta)`.
@@ -497,6 +503,9 @@ class Variable(six.with_metaclass(VariableMetaclass,
Args:
delta: A `Tensor`. The value to subtract from this variable.
use_locking: If `True`, use locking during the operation.
+ name: The name of the operation to be created.
+ read_value: If True, returns something which evaluates to the
+ new value of the variable; if False, returns the assign op.
Returns:
A `Tensor` that will hold the new value of this variable after
@@ -1450,7 +1459,7 @@ class RefVariable(Variable):
"""
return self._constraint
- def assign(self, value, use_locking=False):
+ def assign(self, value, use_locking=False, name=None, read_value=True):
"""Assigns a new value to the variable.
This is essentially a shortcut for `assign(self, value)`.
@@ -1458,14 +1467,21 @@ class RefVariable(Variable):
Args:
value: A `Tensor`. The new value for this variable.
use_locking: If `True`, use locking during the assignment.
+ name: The name of the operation to be created.
+ read_value: if True, will return something which evaluates to the
+ new value of the variable; if False, will return the assign op.
Returns:
A `Tensor` that will hold the new value of this variable after
the assignment has completed.
"""
- return state_ops.assign(self._variable, value, use_locking=use_locking)
+ assign = state_ops.assign(self._variable, value, use_locking=use_locking,
+ name=name)
+ if read_value:
+ return assign
+ return assign.op
- def assign_add(self, delta, use_locking=False):
+ def assign_add(self, delta, use_locking=False, name=None, read_value=True):
"""Adds a value to this variable.
This is essentially a shortcut for `assign_add(self, delta)`.
@@ -1473,14 +1489,21 @@ class RefVariable(Variable):
Args:
delta: A `Tensor`. The value to add to this variable.
use_locking: If `True`, use locking during the operation.
+ name: The name of the operation to be created.
+ read_value: if True, will return something which evaluates to the
+ new value of the variable; if False, will return the assign op.
Returns:
A `Tensor` that will hold the new value of this variable after
the addition has completed.
"""
- return state_ops.assign_add(self._variable, delta, use_locking=use_locking)
+ assign = state_ops.assign_add(
+ self._variable, delta, use_locking=use_locking, name=name)
+ if read_value:
+ return assign
+ return assign.op
- def assign_sub(self, delta, use_locking=False):
+ def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
"""Subtracts a value from this variable.
This is essentially a shortcut for `assign_sub(self, delta)`.
@@ -1488,12 +1511,19 @@ class RefVariable(Variable):
Args:
delta: A `Tensor`. The value to subtract from this variable.
use_locking: If `True`, use locking during the operation.
+ name: The name of the operation to be created.
+ read_value: if True, will return something which evaluates to the
+ new value of the variable; if False, will return the assign op.
Returns:
A `Tensor` that will hold the new value of this variable after
the subtraction has completed.
"""
- return state_ops.assign_sub(self._variable, delta, use_locking=use_locking)
+ assign = state_ops.assign_sub(
+ self._variable, delta, use_locking=use_locking, name=name)
+ if read_value:
+ return assign
+ return assign.op
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
"""Subtracts `IndexedSlices` from this variable.
@@ -2306,10 +2336,15 @@ class PartitionedVariable(object):
def as_tensor(self):
"""Returns the overall concatenated value as a `Tensor`.
+ The returned tensor will not inherit the control dependencies from the scope
+ where the value is used, which is similar to getting the value of
+ `Variable`.
+
Returns:
`Tensor` containing the concatenated value.
"""
- return self._concat()
+ with ops.control_dependencies(None):
+ return self._concat()
@staticmethod
def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):
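For callers, the new `read_value` flag lets each assign method return either a value tensor or the bare op. A short usage sketch against the patched 1.x API (graph mode; the values in comments assume this exact sequence):

    import tensorflow as tf

    v = tf.Variable(3.0)
    # Default read_value=True: returns a tensor evaluating to the new value.
    new_val = v.assign_add(1.0, name="incr")
    # read_value=False: returns only the assign op, which is handy when the
    # result is needed purely as a control dependency.
    incr_op = v.assign_add(1.0, name="incr_no_read", read_value=False)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(new_val))  # 4.0
        sess.run(incr_op)
        print(sess.run(v))        # 5.0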
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 157f2341e0..e1c233cdd9 100644..100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -105,20 +105,29 @@ limitations under the License.
}
}
+// For const parameters in a function, SWIG pretty much ignores the const.
+// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
+// Hence the 'const_cast'.
%typemap(in) const char* serialized_function_def {
- $1 = TFE_GetPythonString($input);
+ $1 = const_cast<char*>(TFE_GetPythonString($input));
}
+// For const parameters in a function, SWIG pretty much ignores the const.
+// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
+// Hence the 'const_cast'.
%typemap(in) const char* device_name {
if ($input == Py_None) {
$1 = nullptr;
} else {
- $1 = TFE_GetPythonString($input);
+ $1 = const_cast<char*>(TFE_GetPythonString($input));
}
}
+// For const parameters in a function, SWIG pretty much ignores the const.
+// See: http://www.swig.org/Doc2.0/SWIG.html#SWIG_nn13
+// Hence the 'const_cast'.
%typemap(in) const char* op_name {
- $1 = TFE_GetPythonString($input);
+ $1 = const_cast<char*>(TFE_GetPythonString($input));
}
%typemap(in) (TFE_Context*) {
diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py
index 20ff34fd8e..06d09325c8 100644
--- a/tensorflow/python/saved_model/utils_impl.py
+++ b/tensorflow/python/saved_model/utils_impl.py
@@ -75,7 +75,7 @@ def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
ValueError: If `tensor_info` is malformed.
"""
- graph = graph if graph is not None else ops.get_default_graph()
+ graph = graph or ops.get_default_graph()
def _get_tensor(name):
return graph.get_tensor_by_name(
ops.prepend_name_scope(name, import_scope=import_scope))
diff --git a/tensorflow/python/tensorflow.i b/tensorflow/python/tensorflow.i
index 26e8acd897..39174fa589 100644
--- a/tensorflow/python/tensorflow.i
+++ b/tensorflow/python/tensorflow.i
@@ -54,4 +54,5 @@ limitations under the License.
%include "tensorflow/python/grappler/item.i"
%include "tensorflow/python/grappler/tf_optimizer.i"
%include "tensorflow/python/grappler/cost_analyzer.i"
+%include "tensorflow/python/grappler/graph_analyzer.i"
%include "tensorflow/python/grappler/model_analyzer.i"
diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl
index 7001e566ce..64f0469482 100644
--- a/tensorflow/python/tools/api/generator/api_init_files.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files.bzl
@@ -25,6 +25,7 @@ TENSORFLOW_API_INIT_FILES = [
"keras/applications/inception_resnet_v2/__init__.py",
"keras/applications/inception_v3/__init__.py",
"keras/applications/mobilenet/__init__.py",
+ "keras/applications/mobilenet_v2/__init__.py",
"keras/applications/nasnet/__init__.py",
"keras/applications/resnet50/__init__.py",
"keras/applications/vgg16/__init__.py",
diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
index 73d11199d9..bc2f3516d1 100644
--- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
@@ -25,6 +25,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
"keras/applications/inception_resnet_v2/__init__.py",
"keras/applications/inception_v3/__init__.py",
"keras/applications/mobilenet/__init__.py",
+ "keras/applications/mobilenet_v2/__init__.py",
"keras/applications/nasnet/__init__.py",
"keras/applications/resnet50/__init__.py",
"keras/applications/vgg16/__init__.py",
diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py
index c7f414c5dc..893309f35a 100644
--- a/tensorflow/python/tools/freeze_graph.py
+++ b/tensorflow/python/tools/freeze_graph.py
@@ -89,7 +89,37 @@ def freeze_graph_with_def_protos(input_graph_def,
input_saved_model_dir=None,
saved_model_tags=None,
checkpoint_version=saver_pb2.SaverDef.V2):
- """Converts all variables in a graph and checkpoint into constants."""
+ """Converts all variables in a graph and checkpoint into constants.
+
+ Args:
+ input_graph_def: A `GraphDef`.
+ input_saver_def: A `SaverDef` (optional).
+ input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
+ priority. Typically the result of `Saver.save()` or that of
+ `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
+ V1/V2.
+ output_node_names: The name(s) of the output nodes, comma separated.
+ restore_op_name: Unused.
+ filename_tensor_name: Unused.
+ output_graph: String where to write the frozen `GraphDef`.
+ clear_devices: A Bool indicating whether to remove device specifications.
+ initializer_nodes: Comma separated string of initializer nodes to run before
+ freezing.
+ variable_names_whitelist: The set of variable names to convert (optional, by
+ default, all variables are converted).
+ variable_names_blacklist: The set of variable names to omit converting
+ to constants (optional).
+ input_meta_graph_def: A `MetaGraphDef` (optional).
+ input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
+ and variables (optional).
+ saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
+ load, in string format (optional).
+ checkpoint_version: TensorFlow variable file format (saver_pb2.SaverDef.V1
+ or saver_pb2.SaverDef.V2).
+
+ Returns:
+ Location of the output_graph_def.
+ """
del restore_op_name, filename_tensor_name # Unused by updated loading code.
# 'input_checkpoint' may be a prefix if we're using Saver V2 format
@@ -271,7 +301,37 @@ def freeze_graph(input_graph,
input_saved_model_dir=None,
saved_model_tags=tag_constants.SERVING,
checkpoint_version=saver_pb2.SaverDef.V2):
- """Converts all variables in a graph and checkpoint into constants."""
+ """Converts all variables in a graph and checkpoint into constants.
+
+ Args:
+ input_graph: A `GraphDef` file to load.
+ input_saver: A TensorFlow Saver file.
+ input_binary: A Bool. True means input_graph is .pb, False indicates .pbtxt.
+ input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
+ priority. Typically the result of `Saver.save()` or that of
+ `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
+ V1/V2.
+ output_node_names: The name(s) of the output nodes, comma separated.
+ restore_op_name: Unused.
+ filename_tensor_name: Unused.
+ output_graph: String where to write the frozen `GraphDef`.
+ clear_devices: A Bool indicating whether to remove device specifications.
+ initializer_nodes: Comma separated list of initializer nodes to run before
+ freezing.
+ variable_names_whitelist: The set of variable names to convert (optional, by
+ default, all variables are converted).
+ variable_names_blacklist: The set of variable names to omit converting
+ to constants (optional).
+ input_meta_graph: A `MetaGraphDef` file to load (optional).
+ input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file and
+ variables (optional).
+ saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
+ load, in string format.
+ checkpoint_version: TensorFlow variable file format (saver_pb2.SaverDef.V1
+ or saver_pb2.SaverDef.V2).
+ Returns:
+ String that is the location of the frozen GraphDef.
+ """
input_graph_def = None
if input_saved_model_dir:
input_graph_def = saved_model_utils.get_meta_graph_def(
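The expanded docstrings make the calling convention easier to see. A hedged end-to-end sketch of `freeze_graph` (all paths and node names here are hypothetical placeholders):

    from tensorflow.python.tools import freeze_graph

    freeze_graph.freeze_graph(
        input_graph="/tmp/model/graph.pbtxt",
        input_saver="",
        input_binary=False,              # .pbtxt rather than .pb
        input_checkpoint="/tmp/model/model.ckpt",
        output_node_names="softmax",     # comma separated for multiple nodes
        restore_op_name="",              # unused by the updated loading code
        filename_tensor_name="",         # unused by the updated loading code
        output_graph="/tmp/model/frozen_graph.pb",
        clear_devices=True,
        initializer_nodes="")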
diff --git a/tensorflow/python/training/adam.py b/tensorflow/python/training/adam.py
index bcbe5907d6..704ad6d3fe 100644
--- a/tensorflow/python/training/adam.py
+++ b/tensorflow/python/training/adam.py
@@ -43,15 +43,15 @@ class AdamOptimizer(optimizer.Optimizer):
Initialization:
- $$m_0 := 0 (Initialize initial 1st moment vector)$$
- $$v_0 := 0 (Initialize initial 2nd moment vector)$$
- $$t := 0 (Initialize timestep)$$
+ $$m_0 := 0 \text{ (Initialize 1st moment vector)}$$
+ $$v_0 := 0 \text{ (Initialize 2nd moment vector)}$$
+ $$t := 0 \text{ (Initialize timestep)}$$
The update rule for `variable` with gradient `g` uses an optimization
described at the end of section 2 of the paper:
$$t := t + 1$$
- $$lr_t := \text{learning_rate} * \sqrt{(1 - beta_2^t) / (1 - beta_1^t)}$$
+ $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
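The LaTeX fix above is substantive, not cosmetic: the square root applies only to the second-moment bias correction, matching the Adam paper. A quick numeric check of the two readings:

    import math

    learning_rate, beta1, beta2, t = 0.001, 0.9, 0.999, 10

    # Corrected form: sqrt only over the beta_2 correction term.
    lr_t = learning_rate * math.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    # Old rendering: sqrt over the whole ratio.
    lr_t_old = learning_rate * math.sqrt((1 - beta2 ** t) / (1 - beta1 ** t))

    print(lr_t)      # ~1.53e-04
    print(lr_t_old)  # ~1.24e-04; the two readings disagree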
diff --git a/tensorflow/python/training/checkpoint_management.py b/tensorflow/python/training/checkpoint_management.py
index 85f2904318..38910fb246 100644
--- a/tensorflow/python/training/checkpoint_management.py
+++ b/tensorflow/python/training/checkpoint_management.py
@@ -510,7 +510,10 @@ class CheckpointManager(object):
max_to_keep: An integer, the number of checkpoints to keep. Unless
preserved by `keep_checkpoint_every_n_hours`, checkpoints will be
deleted from the active set, oldest first, until only `max_to_keep`
- checkpoints remain.
+ checkpoints remain. If `None`, no checkpoints are deleted and everything
+ stays in the active set. Note that `max_to_keep=None` will keep all
+ checkpoint paths in memory and in the checkpoint state protocol buffer
+ on disk.
keep_checkpoint_every_n_hours: Upon removal from the active set, a
checkpoint will be preserved if it has been at least
`keep_checkpoint_every_n_hours` since the last preserved checkpoint. The
@@ -521,9 +524,10 @@ class CheckpointManager(object):
"""
self._checkpoint = checkpoint
self._save_counter_assign = None
- if not max_to_keep or max_to_keep < 0:
+ if max_to_keep is not None and max_to_keep <= 0:
raise ValueError(
- "Expected a positive integer for `max_to_max_to_keep`, got %d."
+ ("Expected a positive integer or `None` for `max_to_max_to_keep`, "
+ "got %d.")
% (max_to_keep,))
self._max_to_keep = max_to_keep
self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
@@ -534,7 +538,9 @@ class CheckpointManager(object):
self._maybe_delete = collections.OrderedDict()
if recovered_state is None:
self._latest_checkpoint = None
- self._last_preserved_timestamp = current_clock
+ # Set the clock back slightly to avoid race conditions when quickly
+ # re-creating a CheckpointManager.
+ self._last_preserved_timestamp = current_clock - 1.
else:
self._latest_checkpoint = recovered_state.model_checkpoint_path
self._last_preserved_timestamp = recovered_state.last_preserved_timestamp
@@ -586,6 +592,10 @@ class CheckpointManager(object):
def _sweep(self):
"""Deletes or preserves managed checkpoints."""
+ if not self._max_to_keep:
+ # Does not update self._last_preserved_timestamp, since everything is kept
+ # in the active set.
+ return
while len(self._maybe_delete) > self._max_to_keep:
filename, timestamp = self._maybe_delete.popitem(last=False)
# Even if we're keeping this checkpoint due to
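With `max_to_keep=None` now accepted, a manager can be asked to keep everything. A minimal eager-mode sketch (the directory path is a placeholder):

    import tensorflow as tf
    from tensorflow.python.training import checkpoint_management
    from tensorflow.python.training.checkpointable import util as checkpointable_utils

    tf.enable_eager_execution()

    checkpoint = checkpointable_utils.Checkpoint(v=tf.Variable(1.))
    manager = checkpoint_management.CheckpointManager(
        checkpoint, directory="/tmp/keep_all_ckpts", max_to_keep=None)
    for _ in range(3):
        manager.save()
    # Nothing is deleted; note that every path also stays in the on-disk
    # checkpoint state proto, so this grows without bound.
    print(manager.checkpoints)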
diff --git a/tensorflow/python/training/checkpoint_management_test.py b/tensorflow/python/training/checkpoint_management_test.py
index 22c2cc678a..8ef5048299 100644
--- a/tensorflow/python/training/checkpoint_management_test.py
+++ b/tensorflow/python/training/checkpoint_management_test.py
@@ -26,6 +26,7 @@ import tempfile
from google.protobuf import text_format
from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import test_util
@@ -333,6 +334,49 @@ class CheckpointManagerTest(test.TestCase):
self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
@test_util.run_in_graph_and_eager_modes
+ def testKeepAll(self):
+ checkpoint = util.Checkpoint()
+ directory = os.path.join(
+ self.get_temp_dir(),
+ # Avoid sharing directories between eager and graph
+ # TODO(allenl): stop run_in_graph_and_eager_modes reusing directories
+ str(context.executing_eagerly()))
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=None)
+ first_path = manager.save()
+ second_path = manager.save()
+ third_path = manager.save()
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
+ self.assertEqual(third_path, manager.latest_checkpoint)
+ self.assertEqual([first_path, second_path, third_path],
+ manager.checkpoints)
+ del manager
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=None)
+ fourth_path = manager.save()
+ self.assertEqual([first_path, second_path, third_path, fourth_path],
+ manager.checkpoints)
+ del manager
+ manager = checkpoint_management.CheckpointManager(
+ checkpoint, directory, max_to_keep=3)
+ self.assertEqual([first_path, second_path, third_path, fourth_path],
+ manager.checkpoints)
+ self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(second_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(first_path))
+ fifth_path = manager.save()
+ self.assertEqual([third_path, fourth_path, fifth_path],
+ manager.checkpoints)
+ self.assertTrue(checkpoint_management.checkpoint_exists(fifth_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(fourth_path))
+ self.assertTrue(checkpoint_management.checkpoint_exists(third_path))
+ self.assertFalse(checkpoint_management.checkpoint_exists(second_path))
+ self.assertFalse(checkpoint_management.checkpoint_exists(first_path))
+
+ @test_util.run_in_graph_and_eager_modes
@test.mock.patch.object(checkpoint_management, "time")
def testSaveRestoreState(self, mock_time):
directory = self.get_temp_dir()
@@ -345,8 +389,6 @@ class CheckpointManagerTest(test.TestCase):
mock_time.time.return_value = first_time
first_manager.save()
state = checkpoint_management.get_checkpoint_state(directory)
- self.assertEqual([first_time], state.all_model_checkpoint_timestamps)
- self.assertEqual(3., state.last_preserved_timestamp)
second_time = first_time + 3610.
second_name = os.path.join(directory, "ckpt-2")
mock_time.time.return_value = second_time
@@ -354,7 +396,6 @@ class CheckpointManagerTest(test.TestCase):
state = checkpoint_management.get_checkpoint_state(directory)
self.assertEqual([first_time, second_time],
state.all_model_checkpoint_timestamps)
- self.assertEqual(3., state.last_preserved_timestamp)
self.assertEqual([first_name, second_name], first_manager.checkpoints)
self.assertEqual(second_name, first_manager.latest_checkpoint)
del first_manager
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index f49ed5c9ff..45d217e8b1 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -199,6 +199,7 @@ class _NameBasedRestoreCoordinator(object):
for saveable in self.globally_named_object_attributes(
checkpointable):
restored_tensors = []
+ tensor_missing = False
for spec in saveable.specs:
if spec.name in self.dtype_map:
with ops.device("cpu:0"):
@@ -209,9 +210,15 @@ class _NameBasedRestoreCoordinator(object):
dtypes=[self.dtype_map[spec.name]],
name="%s_checkpoint_read" % (spec.name,))
restored_tensors.append(array_ops.identity(restored))
+ else:
+ tensor_missing = True
- saveable.restore(restored_tensors=restored_tensors,
- restored_shapes=None)
+ if not tensor_missing:
+ # Ignores values missing from the checkpoint, as with object-based
+ # restore. Status assertions can be used to check exact matches,
+ # although it's unlikely to ever happen for name-based checkpoints.
+ saveable.restore(restored_tensors=restored_tensors,
+ restored_shapes=None)
# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange
@@ -834,6 +841,11 @@ class _LoadStatus(object):
pass
@abc.abstractmethod
+ def assert_existing_objects_matched(self):
+ """Raises an exception unless existing Python objects have been matched."""
+ pass
+
+ @abc.abstractmethod
def run_restore_ops(self, session=None):
"""Runs restore ops from the checkpoint. Requires a valid checkpoint."""
pass
@@ -903,13 +915,11 @@ class CheckpointLoadStatus(_LoadStatus):
or if there are any checkpointed values which have not been matched to
Python objects.
"""
+ self.assert_existing_objects_matched()
for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
checkpointable = self._checkpoint.object_by_proto_id.get(node_id, None)
if checkpointable is None:
raise AssertionError("Unresolved object in checkpoint: %s" % (node,))
- if checkpointable._update_uid < self._checkpoint.restore_uid: # pylint: disable=protected-access
- raise AssertionError(
- "Object not assigned a value from checkpoint: %s" % (node,))
if self._checkpoint.slot_restorations:
# Sanity check; this collection should be clear if everything has been
# restored.
@@ -920,6 +930,31 @@ class CheckpointLoadStatus(_LoadStatus):
("Unused attributes in these objects (the attributes exist in the "
"checkpoint but not in the objects): %s") % (
self._checkpoint.unused_attributes.items(),))
+ return self
+
+ def assert_existing_objects_matched(self):
+ """Asserts that checkpointable Python objects have been matched.
+
+ Note that this is a weaker assertion than `assert_consumed`. It will only
+ fail for existing Python objects which are (transitive) dependencies of the
+ root object and which do not have an entry in the checkpoint.
+
+ It will not fail, for example, if a `tf.keras.Layer` object has not yet been
+ built and so has not created any `tf.Variable` objects.
+
+ Returns:
+ `self` for chaining.
+
+ Raises:
+ AssertionError: If a Python object exists in the transitive dependencies
+ of the root object but does not have a value in the checkpoint.
+ """
+ for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
+ checkpointable = self._checkpoint.object_by_proto_id.get(node_id, None)
+ if (checkpointable is not None
+ and checkpointable._update_uid < self._checkpoint.restore_uid): # pylint: disable=protected-access
+ raise AssertionError(
+ "Object not assigned a value from checkpoint: %s" % (node,))
for checkpointable_object in list_objects(self._root_checkpointable):
self._checkpoint.all_python_objects.add(checkpointable_object)
unused_python_objects = (
@@ -929,7 +964,7 @@ class CheckpointLoadStatus(_LoadStatus):
raise AssertionError(
("Some Python objects were not bound to checkpointed values, likely "
"due to changes in the Python program: %s")
- % (unused_python_objects,))
+ % (list(unused_python_objects),))
return self
def run_restore_ops(self, session=None):
@@ -991,6 +1026,11 @@ class InitializationOnlyStatus(_LoadStatus):
raise AssertionError(
"No checkpoint specified (save_path=None); nothing is being restored.")
+ def assert_existing_objects_matched(self):
+ """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
+ raise AssertionError(
+ "No checkpoint specified (save_path=None); nothing is being restored.")
+
def run_restore_ops(self, session=None):
"""For consistency with `CheckpointLoadStatus`.
@@ -1064,6 +1104,15 @@ class NameBasedSaverStatus(_LoadStatus):
if checkpointable._update_uid < self._checkpoint.restore_uid:
raise AssertionError("Object not restored: %s" % (checkpointable,))
# pylint: enable=protected-access
+ return self
+
+ def assert_existing_objects_matched(self):
+ """Raises an exception if currently created objects are unmatched."""
+ # For name-based checkpoints there's no object information in the
+ # checkpoint, so there's no distinction between
+ # assert_existing_objects_matched and assert_consumed (and both are less
+ # useful since we don't touch Python objects or Python state).
+ return self.assert_consumed()
def _gather_saveable_objects(self):
"""Walk the object graph, using global names for SaveableObjects."""
@@ -1647,6 +1696,17 @@ class Checkpoint(tracking.Checkpointable):
Python objects in the dependency graph with no values in the
checkpoint. This method returns the status object, and so may be
chained with `initialize_or_restore` or `run_restore_ops`.
+ - `assert_existing_objects_matched()`:
+ Raises an exception if any existing Python objects in the dependency
+ graph are unmatched. Unlike `assert_consumed`, this assertion will
+ pass if values in the checkpoint have no corresponding Python
+ objects. For example a `tf.keras.Layer` object which has not yet been
+ built, and so has not created any variables, will pass this assertion
+ but fail `assert_consumed`. Useful when loading part of a larger
+ checkpoint into a new Python program, e.g. when a training checkpoint
+ containing a `tf.train.Optimizer` was saved but only the state required for
+ inference is being loaded. This method returns the status object, and
+ so may be chained with `initialize_or_restore` or `run_restore_ops`.
- `initialize_or_restore(session=None)`:
When graph building, runs variable initializers if `save_path` is
`None`, but otherwise runs restore operations. If no `session` is
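A small sketch of the new assertion's intended use, restoring a subset of a saved object graph (eager mode; the names and path are illustrative):

    import tensorflow as tf
    from tensorflow.python.training.checkpointable import util as checkpointable_utils

    tf.enable_eager_execution()

    full = checkpointable_utils.Checkpoint(v=tf.Variable(1.), w=tf.Variable(2.))
    save_path = full.save("/tmp/partial_restore/ckpt")

    # Rebuild only part of the object graph.
    partial = checkpointable_utils.Checkpoint(v=tf.Variable(0.))
    status = partial.restore(save_path)
    status.assert_existing_objects_matched()  # passes: `v` found a value
    # status.assert_consumed() would raise, since `w`'s checkpointed value is
    # never claimed by any Python object.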
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index 522167b49c..bef4bf2a16 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -384,8 +384,8 @@ class CheckpointingTests(test.TestCase):
saver = saver_lib.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
- self.evaluate(v.non_dep_variable.assign(42.))
with self.test_session() as sess:
+ self.evaluate(v.non_dep_variable.assign(42.))
save_path = saver.save(sess, prefix)
self.evaluate(v.non_dep_variable.assign(43.))
self.evaluate(v.mirrored.assign(44.))
@@ -437,6 +437,9 @@ class CheckpointingTests(test.TestCase):
optimizer=on_create_optimizer, model=on_create_model)
# Deferred restoration
status = on_create_root.restore(save_path=save_path)
+ status.assert_existing_objects_matched()
+ with self.assertRaises(AssertionError):
+ status.assert_consumed()
on_create_model(constant_op.constant([[3.]])) # create variables
self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
self.assertAllEqual([42.],
@@ -444,6 +447,9 @@ class CheckpointingTests(test.TestCase):
on_create_model._named_dense.variables[1]))
on_create_m_bias_slot = on_create_optimizer.get_slot(
on_create_model._named_dense.variables[1], "m")
+ status.assert_existing_objects_matched()
+ with self.assertRaises(AssertionError):
+ status.assert_consumed()
# Optimizer slot variables are created when the original variable is
# restored.
self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
@@ -451,6 +457,7 @@ class CheckpointingTests(test.TestCase):
self.evaluate(on_create_optimizer.variables()))
dummy_var = resource_variable_ops.ResourceVariable([1.])
on_create_optimizer.minimize(loss=dummy_var.read_value)
+ status.assert_existing_objects_matched()
status.assert_consumed()
beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
@@ -506,8 +513,11 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(0, training_continuation)
with self.assertRaises(AssertionError):
status.assert_consumed()
+ with self.assertRaises(AssertionError):
+ status.assert_existing_objects_matched()
else:
status.assert_consumed()
+ status.assert_existing_objects_matched()
for _ in range(num_training_steps):
session.run(train_op)
root.save(file_prefix=checkpoint_prefix, session=session)
@@ -704,11 +714,12 @@ class CheckpointingTests(test.TestCase):
load_into = LateDependencies()
status = checkpointable_utils.CheckpointableSaver(
load_into).restore(save_path)
+ status.assert_existing_objects_matched()
with self.assertRaises(AssertionError):
status.assert_consumed()
load_into.add_dep()
status.assert_consumed()
- status.run_restore_ops()
+ status.assert_existing_objects_matched().run_restore_ops()
self.assertEqual(123., self.evaluate(load_into.dep.var))
@test_util.run_in_graph_and_eager_modes
@@ -785,6 +796,7 @@ class CheckpointingTests(test.TestCase):
no_slot_status.run_restore_ops()
self.assertEqual(12., self.evaluate(new_root.var))
new_root.optimizer = adam.AdamOptimizer(0.1)
+ slot_status.assert_existing_objects_matched()
with self.assertRaisesRegexp(AssertionError, "beta1_power"):
slot_status.assert_consumed()
self.assertEqual(12., self.evaluate(new_root.var))
@@ -884,6 +896,8 @@ class CheckpointingTests(test.TestCase):
load_root.dep_one.dep_three, name="var", initializer=0.)
with self.assertRaises(AssertionError):
status.assert_consumed()
+ with self.assertRaises(AssertionError):
+ status.assert_existing_objects_matched()
@test_util.run_in_graph_and_eager_modes
def testObjectsCombined(self):
@@ -907,7 +921,7 @@ class CheckpointingTests(test.TestCase):
v2 = checkpointable_utils.add_variable(
load_root.dep_one, name="var2", shape=[], dtype=dtypes.float64)
status = checkpointable_utils.CheckpointableSaver(load_root).restore(
- save_path).assert_consumed()
+ save_path).assert_consumed().assert_existing_objects_matched()
status.run_restore_ops()
self.assertEqual(32., self.evaluate(v1))
self.assertEqual(64., self.evaluate(v2))
@@ -1239,6 +1253,8 @@ class CheckpointingTests(test.TestCase):
status.initialize_or_restore()
train_fn()
with self.assertRaises(AssertionError):
+ status.assert_existing_objects_matched()
+ with self.assertRaises(AssertionError):
status.assert_consumed()
# Make sure initialization doesn't clobber later restores
@@ -1451,17 +1467,27 @@ class CheckpointCompatibilityTests(test.TestCase):
if context.executing_eagerly():
with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
status.assert_consumed()
+ with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
+ status.assert_existing_objects_matched()
else:
# When graph building, we haven't read any keys, so we don't know
# whether the restore will be complete.
with self.assertRaisesRegexp(AssertionError, "not restored"):
status.assert_consumed()
+ with self.assertRaisesRegexp(AssertionError, "not restored"):
+ status.assert_existing_objects_matched()
status.run_restore_ops()
self._check_sentinels(root)
self._set_sentinels(root)
status = object_saver.restore(save_path)
status.initialize_or_restore()
self._check_sentinels(root)
+ # Check that there is no error when keys are missing from the name-based
+ # checkpoint.
+ root.not_in_name_checkpoint = resource_variable_ops.ResourceVariable([1.])
+ status = object_saver.restore(save_path)
+ with self.assertRaises(AssertionError):
+ status.assert_existing_objects_matched()
def testSaveGraphLoadEager(self):
checkpoint_directory = self.get_temp_dir()
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 4b91d1e963..177a7ddfa5 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -363,10 +363,12 @@ class ExponentialMovingAverage(object):
`GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to
`tf.global_variables()`.
- Returns an op that updates all shadow variables as described above.
+ Returns an op that updates all shadow variables from the current value of
+ their associated variables.
- Note that `apply()` can be called multiple times with different lists of
- variables.
+ Note that `apply()` can be called multiple times. When eager execution is
+ enabled, each call to `apply()` updates the variables once, so it needs to
+ be called in a loop.
Args:
var_list: A list of Variable or Tensor objects. The variables
@@ -389,31 +391,30 @@ class ExponentialMovingAverage(object):
dtypes.float64]:
raise TypeError("The variables must be half, float, or double: %s" %
var.name)
- if var in self._averages:
- raise ValueError("Moving average already computed for: %s" % var.name)
- # For variables: to lower communication bandwidth across devices we keep
- # the moving averages on the same device as the variables. For other
- # tensors, we rely on the existing device allocation mechanism.
- with ops.init_scope():
- if isinstance(var, variables.Variable):
- avg = slot_creator.create_slot(var,
- var.initialized_value(),
- self.name,
- colocate_with_primary=True)
- # NOTE(mrry): We only add `tf.Variable` objects to the
- # `MOVING_AVERAGE_VARIABLES` collection.
- ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
- else:
- avg = slot_creator.create_zeros_slot(
- var,
- self.name,
- colocate_with_primary=(var.op.type in ["Variable",
- "VariableV2",
- "VarHandleOp"]))
- if self._zero_debias:
- zero_debias_true.add(avg)
- self._averages[var] = avg
+ if var not in self._averages:
+ # For variables: to lower communication bandwidth across devices we keep
+ # the moving averages on the same device as the variables. For other
+ # tensors, we rely on the existing device allocation mechanism.
+ with ops.init_scope():
+ if isinstance(var, variables.Variable):
+ avg = slot_creator.create_slot(var,
+ var.initialized_value(),
+ self.name,
+ colocate_with_primary=True)
+ # NOTE(mrry): We only add `tf.Variable` objects to the
+ # `MOVING_AVERAGE_VARIABLES` collection.
+ ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
+ else:
+ avg = slot_creator.create_zeros_slot(
+ var,
+ self.name,
+ colocate_with_primary=(var.op.type in ["Variable",
+ "VariableV2",
+ "VarHandleOp"]))
+ if self._zero_debias:
+ zero_debias_true.add(avg)
+ self._averages[var] = avg
with ops.name_scope(self.name) as scope:
decay = ops.convert_to_tensor(self._decay, name="decay")
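As the revised docstring notes, under eager execution each `apply()` call performs one update, so it moves into the training loop. A minimal sketch under that assumption:

    import tensorflow as tf

    tf.enable_eager_execution()

    v = tf.Variable(1.0)
    ema = tf.train.ExponentialMovingAverage(0.25)

    for _ in range(3):
        v.assign_add(1.0)
        ema.apply([v])  # re-calling apply() is now allowed; one update per call

    print(ema.average(v).numpy())  # shadow value trails the raw variable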
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 3e85e6bfa7..fdb8d795c3 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -18,9 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import variable_scope
@@ -254,6 +256,25 @@ class ExponentialMovingAverageTest(test.TestCase):
self.assertEqual(1, sess.run(v0))
self.assertEqual([17.5], sess.run(v1_avg))
+ @test_util.run_in_graph_and_eager_modes
+ def testBasicEager(self):
+ v0 = variables.Variable(1.0)
+ v1 = variables.Variable(2.0)
+
+ ema = moving_averages.ExponentialMovingAverage(0.25)
+ op = ema.apply([v0, v1])
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(op)
+
+ self.evaluate(v0.assign(2.0))
+ self.evaluate(v1.assign(4.0))
+
+ self.evaluate(ema.apply([v0, v1]))
+
+ self.assertAllEqual(self.evaluate(ema.average(v0)), 1.75)
+ self.assertAllEqual(self.evaluate(ema.average(v1)), 3.5)
+
def averageVariablesNamesHelper(self, zero_debias):
with self.test_session():
v0 = variables.Variable(10.0, name="v0")
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 1b6bce2865..2304a461c1 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -772,16 +772,15 @@ class Optimizer(
Returns:
A list of variables.
"""
- executing_eagerly = context.executing_eagerly()
current_graph = ops.get_default_graph()
def _from_current_graph(variable):
- if executing_eagerly:
+ if variable._in_graph_mode: # pylint: disable=protected-access
+ return variable.op.graph is current_graph
+ else:
# No variable.op in eager mode. We don't expect lots of eager graphs,
# but behavior should be consistent with graph mode.
return variable._graph_key == current_graph._graph_key # pylint: disable=protected-access
- else:
- return variable.op.graph is current_graph
optimizer_variables = [v for v in self._non_slot_variables()
if _from_current_graph(v)]
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index b46095d458..f5b2a22327 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -2853,8 +2853,8 @@ class CheckpointableCompatibilityTests(test.TestCase):
saver = saver_module.Saver(var_list=[v])
test_dir = self.get_temp_dir()
prefix = os.path.join(test_dir, "ckpt")
- self.evaluate(v.non_dep_variable.assign(42.))
with self.test_session() as sess:
+ self.evaluate(v.non_dep_variable.assign(42.))
save_path = saver.save(sess, prefix)
self.evaluate(v.non_dep_variable.assign(43.))
saver.restore(sess, save_path)
diff --git a/tensorflow/python/util/tf_export.py b/tensorflow/python/util/tf_export.py
index 2be4dbb283..a5ac430ce7 100644
--- a/tensorflow/python/util/tf_export.py
+++ b/tensorflow/python/util/tf_export.py
@@ -136,11 +136,14 @@ class api_export(object): # pylint: disable=invalid-name
has no effect on exporting a constant.
api_name: Name of the API you want to generate (e.g. `tensorflow` or
`estimator`). Default is `tensorflow`.
+ allow_multiple_exports: Allow symbol to be exported multiple times under
+ different names.
"""
self._names = args
self._names_v1 = kwargs.get('v1', args)
self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
self._overrides = kwargs.get('overrides', [])
+ self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False)
def __call__(self, func):
"""Calls this decorator.
@@ -173,9 +176,10 @@ class api_export(object): # pylint: disable=invalid-name
# __dict__ instead of using hasattr to verify that subclasses have
# their own _tf_api_names as opposed to just inheriting it.
if api_names_attr in func.__dict__:
- raise SymbolAlreadyExposedError(
- 'Symbol %s is already exposed as %s.' %
- (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access
+ if not self._allow_multiple_exports:
+ raise SymbolAlreadyExposedError(
+ 'Symbol %s is already exposed as %s.' %
+ (func.__name__, getattr(func, api_names_attr))) # pylint: disable=protected-access
setattr(func, api_names_attr, names)
def export_constant(self, module_name, name):
@@ -213,4 +217,5 @@ class api_export(object): # pylint: disable=invalid-name
tf_export = functools.partial(api_export, api_name=TENSORFLOW_API_NAME)
-estimator_export = functools.partial(api_export, api_name=ESTIMATOR_API_NAME)
+estimator_export = functools.partial(
+ api_export, api_name=ESTIMATOR_API_NAME, allow_multiple_exports=True)
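Since `estimator_export` now passes `allow_multiple_exports=True`, re-exporting the same callable under an extra name is tolerated. A hedged sketch (the symbol names are made up):

    from tensorflow.python.util import tf_export

    @tf_export.estimator_export("estimator.my_util")
    def my_util():
        pass

    # A second export of the same function no longer raises
    # SymbolAlreadyExposedError; the names attribute is simply overwritten.
    tf_export.estimator_export("estimator.my_util_alias")(my_util)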
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 61249d664b..562bbdcfeb 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -470,12 +470,14 @@ void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
// Leaves `error_msg` empty if structures matched. Else, fills `error_msg`
// with appropriate error and sets `is_type_error` to true iff
// the error to be raised should be TypeError.
-bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
- string* error_msg, bool* is_type_error) {
+bool AssertSameStructureHelper(
+ PyObject* o1, PyObject* o2, bool check_types, string* error_msg,
+ bool* is_type_error,
+ const std::function<int(PyObject*)>& is_sequence_helper) {
DCHECK(error_msg);
DCHECK(is_type_error);
- const bool is_seq1 = IsSequence(o1);
- const bool is_seq2 = IsSequence(o2);
+ const bool is_seq1 = is_sequence_helper(o1);
+ const bool is_seq2 = is_sequence_helper(o2);
if (PyErr_Occurred()) return false;
if (is_seq1 != is_seq2) {
string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2);
@@ -487,7 +489,9 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
return true;
}
- // Got to scalars, so finished checking. Structures are the same.
+ // Got to objects that are considered non-sequences. Note that in the tf.data
+ // use case, lists and sparse_tensors are not considered sequences, so we have
+ // finished checking: the structures are the same.
if (!is_seq1) return true;
if (check_types) {
@@ -586,7 +590,7 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
return false;
}
bool no_internal_errors = AssertSameStructureHelper(
- v1, v2, check_types, error_msg, is_type_error);
+ v1, v2, check_types, error_msg, is_type_error, is_sequence_helper);
Py_LeaveRecursiveCall();
if (!no_internal_errors) return false;
if (!error_msg->empty()) return true;
@@ -759,7 +763,32 @@ PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types) {
string error_msg;
bool is_type_error = false;
- AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error);
+ AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
+ IsSequenceHelper);
+ if (PyErr_Occurred()) {
+ // Don't hide Python exceptions while checking (e.g. errors fetching keys
+ // from custom mappings).
+ return nullptr;
+ }
+ if (!error_msg.empty()) {
+ PyErr_SetString(
+ is_type_error ? PyExc_TypeError : PyExc_ValueError,
+ tensorflow::strings::StrCat(
+ "The two structures don't have the same nested structure.\n\n",
+ "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
+ PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
+ .c_str());
+ return nullptr;
+ }
+ Py_RETURN_NONE;
+}
+
+PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
+ bool check_types) {
+ string error_msg;
+ bool is_type_error = false;
+ AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
+ IsSequenceForDataHelper);
if (PyErr_Occurred()) {
// Don't hide Python exceptions while checking (e.g. errors fetching keys
// from custom mappings).
diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h
index f15ebb6efe..343605285e 100644
--- a/tensorflow/python/util/util.h
+++ b/tensorflow/python/util/util.h
@@ -144,16 +144,20 @@ void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class);
// 1. It removes support for lists as a level of nesting in nested structures.
// 2. It adds support for `SparseTensorValue` as an atomic element.
-// IsSequence specialized for the data package. Additional comments about
-// difference in functionality can be found in nest.py in tensorflow.data.util
-// and in the comments for Flatten above.
+// IsSequence specialized for `tf.data`. Additional comments about
+// difference in functionality can be found in nest.py in
+// `tensorflow.python.data.util` and in the comments for Flatten above.
bool IsSequenceForData(PyObject* o);
-// IsSequence specialized for the data package. Additional comments about
-// difference in functionality can be found in nest.py in tensorflow.data.util
-// and in the comments for Flatten above.
+// Flatten specialized for `tf.data`. Additional comments about
+// difference in functionality can be found in nest.py in
+// `tensorflow.python.data.util` and in the comments for Flatten above.
PyObject* FlattenForData(PyObject* nested);
+// AssertSameStructure specialized for `tf.data`.
+PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
+ bool check_types);
+
} // namespace swig
} // namespace tensorflow
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 8d9f9615d7..6d336ac39d 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -110,6 +110,9 @@ Raises:
%unignore tensorflow::swig::FlattenForData;
%noexception tensorflow::swig::FlattenForData;
+%unignore tensorflow::swig::AssertSameStructureForData;
+%noexception tensorflow::swig::AssertSameStructureForData;
+
%include "tensorflow/python/util/util.h"
%unignoreall
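As a Python-level counterpart to the new `AssertSameStructureForData`, the `tf.data` nest utilities expose the same semantics; a small sketch, relying on lists being treated as atoms rather than as a level of nesting:

    from tensorflow.python.data.util import nest as data_nest

    # In the tf.data flavor, lists are leaves, so structures whose list
    # elements differ in length still match.
    data_nest.assert_same_structure({"a": [1, 2], "b": 3},
                                    {"a": [1, 2, 3], "b": 4})
    print("structures match")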